Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
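
The change set is mechanical: wherever a list(), dict(), set(), or tuple() call wraps a comprehension, generator, lambda-map, or literal, it is rewritten into the equivalent comprehension or literal syntax flagged by flake8-comprehensions (ruff's C4 rule family). A minimal sketch of the patterns applied below, with made-up values; assuming a recent ruff, the same fixes come from running: ruff check --select C4 --fix .

# Illustrative before/after pairs for the C4 fixes in this commit (values are made up).
items = {"a": 1, "b": 2}

squares = [n * n for n in range(5)]             # C400, was: list(n * n for n in range(5))
doubled = {k: v * 2 for k, v in items.items()}  # C402, was: dict((k, v * 2) for k, v in items.items())
upper = {k.upper() for k in items}              # C403, was: set([k.upper() for k in items])
config = {"a": 1, "b": 2}                       # C408, was: dict(a=1, b=2)
ordered = sorted(upper)                         # C414, was: sorted(list(upper))
single = (1,)                                   # C409, was: tuple([1])
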
@@ -685,9 +685,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -725,7 +725,7 @@ def main():
             for i in range(model.num_layers):
                 info_str += " {:.2f}".format(100 * each_layer_results[i])
             logger.info(info_str)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
......
@@ -27,7 +27,7 @@ from utils import logger
 def _quantize(x, bins):
     bins = copy.deepcopy(bins)
     bins = sorted(bins)
-    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    quantized = [bisect.bisect_right(bins, y) for y in x]
     return quantized
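
The hunk above is the C417 pattern: a map over a lambda becomes a plain comprehension, which drops one function call per element and names the loop variable directly. A quick equivalence check with stand-in values (bins and x here are illustrative, not the project's data):

import bisect

bins = [1, 4, 9]
x = [0, 3, 5, 10]
old = list(map(lambda y: bisect.bisect_right(bins, y), x))
new = [bisect.bisect_right(bins, y) for y in x]
assert old == new == [0, 1, 2, 3]  # each value bucketed by its insertion point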
......
@@ -850,9 +850,9 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -865,7 +865,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))
......
@@ -247,9 +247,12 @@ class Trainer:
                 lr = self.scheduler_fn(state_step - 1)
                 eval_loss = self.evaluate(state, val_dataset)
-                logging_dict = dict(
-                    step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
-                )
+                logging_dict = {
+                    "step": state_step.item(),
+                    "eval_loss": eval_loss.item(),
+                    "tr_loss": tr_loss,
+                    "lr": lr.item(),
+                }
                 tqdm.write(str(logging_dict))
                 self.logger.log(logging_dict, commit=True)
......
@@ -144,9 +144,9 @@ def main():
     predictions = expand_to_aliases(example["output"])
     # some preprocessing to both prediction and answer
-    answers = set(["".join(a.split()) for a in answers])
-    predictions = set(["".join(p.split()) for p in predictions])
-    predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
+    answers = {"".join(a.split()) for a in answers}
+    predictions = {"".join(p.split()) for p in predictions}
+    predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}
     # if there is a common element, it's an exact match
     example["match"] = len(list(answers & predictions)) > 0
......
@@ -314,12 +314,12 @@ if __name__ == "__main__":
     data = data["train" if PROCESS_TRAIN == "true" else "validation"]
-    fn_kwargs = dict(
-        tokenizer=tokenizer,
-        doc_stride=DOC_STRIDE,
-        max_length=MAX_LENGTH,
-        assertion=False,
-    )
+    fn_kwargs = {
+        "tokenizer": tokenizer,
+        "doc_stride": DOC_STRIDE,
+        "max_length": MAX_LENGTH,
+        "assertion": False,
+    }
     data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
     data = data.remove_columns(["annotations", "document", "id", "question"])
     print(data)
......
@@ -34,7 +34,7 @@ empty_dict = object()
 def _match(qs, ks):
     """Return True if regexes in qs match any window of strings in tuple ks."""
     # compile regexes and force complete match
-    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+    qts = tuple((re.compile(x + "$") for x in qs))
     for i in range(len(ks) - len(qs) + 1):
         matches = [x.match(y) for x, y in zip(qts, ks[i:])]
         if matches and all(matches):
......
@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
     )
     hits = response["hits"]["hits"]
     support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
-    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
+    res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
     for r, hit in zip(res_list, hits):
         r["passage_id"] = hit["_id"]
         r["score"] = hit["_score"]
@@ -601,7 +601,7 @@ def make_qa_dense_index(
     fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
     n_batches = math.ceil(passages_dset.num_rows / batch_size)
     for i in range(n_batches):
-        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
+        passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
         reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
         fp[i * batch_size : (i + 1) * batch_size] = reps
         if i % 50 == 0:
@@ -634,7 +634,7 @@ def query_qa_dense_index(
     D, I = wiki_index.search(q_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc in zip(res_list, D[0]):
         r["score"] = float(sc)
@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
     ]
     all_res_lists = []
     for res_passages, dl in zip(res_passages_lst, D):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc in zip(res_list, dl):
             r["score"] = float(sc)
         all_res_lists += [res_list[:]]
@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
     D, I = wiki_index.search(a_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc, i in zip(res_list, D[0], I[0]):
         r["passage_id"] = int(i)
@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
     ]
     all_res_lists = []
     for res_passages, dl, il in zip(res_passages_lst, D, I):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc, i in zip(res_list, dl, il):
             r["passage_id"] = int(i)
             r["score"] = float(sc)
......
@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None
......
@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)
         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
......
@@ -554,9 +554,9 @@ def main():
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -566,7 +566,7 @@ def main():
             model.load_state_dict(torch.load(checkpoint))
             model.to(args.device)
             result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
......
@@ -941,9 +941,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -953,7 +953,7 @@ def main():
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results
......
@@ -1109,10 +1109,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c)
                 for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
     else:
         logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
@@ -1129,7 +1129,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))
......
@@ -42,8 +42,8 @@ def _graph_replace_input_with(graph_proto, name, new_name):
 def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
-    inits_with_data = [i for i in model.graph.initializer]
-    inits = [i for i in model_without_ext.graph.initializer]
+    inits_with_data = list(model.graph.initializer)
+    inits = list(model_without_ext.graph.initializer)
     for i, ref_i in ind_to_replace:
         assert inits_with_data[i].name == inits[i].name
         assert inits_with_data[ref_i].name == inits[ref_i].name
@@ -69,7 +69,7 @@ def remove_dup_initializers(onnx_file_path):
     model = onnx.load(os.path.join(model_file_folder, model_file_name))
-    inits = [i for i in model.graph.initializer]
+    inits = list(model.graph.initializer)
     dup_set = set()
     dup_map = {}
......
@@ -127,11 +127,9 @@ def perturb_past(
     _, _, _, curr_length, _ = past[0].shape
     if curr_length > window_length and window_length > 0:
-        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
+        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])
-        zeros_key_val_shape = (
-            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
-        )
+        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])
         ones_mask = torch.ones(ones_key_val_shape)
         ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
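
This hunk also applies C409: tuple([n]) builds a throwaway list just to get a one-element tuple, while the literal (n,) concatenates directly. A small sketch with a hypothetical 5-D shape standing in for the past key/value tensors:

shape = (2, 1, 16, 10, 64)  # illustrative; past[0].shape plays this role above
window_length = 3
ones_key_val_shape = shape[:-2] + (window_length,) + shape[-1:]
assert ones_key_val_shape == (2, 1, 16, 3, 64)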
......
@@ -164,11 +164,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
......
@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
......
@@ -162,11 +162,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
......
@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
......
@@ -344,7 +344,7 @@ def create_vocabulary_from_data(
         lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
     )
-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}
     # replace white space with delimiter token
     if word_delimiter_token is not None:
......