Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
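
For reference: flake8-comprehensions is the C4 rule family in ruff. It flags redundant `dict()`/`list()`/`set()` calls and unnecessary comprehensions, and its autofix rewrites them into the literal and comprehension forms seen throughout this diff. Below is a minimal before/after sketch of the patterns involved; the sample values are illustrative, not taken from the repository.

```python
# Illustrative inputs (hypothetical, only to make the snippet runnable).
examples = {"caption": ["a cat", "a dog"]}
text = "a cat sat"
pairs = [("a", 1), ("b", 2)]

# Before: patterns that flake8-comprehensions flags.
data_files = dict()                           # C408: dict() call instead of a literal
captions = [c for c in examples["caption"]]   # C416: comprehension that just copies
tokens = set([t for t in text.split() if t])  # C403: list comprehension wrapped in set()
result = dict((k, v) for k, v in pairs)       # C402: generator of pairs wrapped in dict()

# After: the forms the autofix produces, identical in behavior.
data_files = {}
captions = list(examples["caption"])
tokens = {t for t in text.split() if t}
result = {k: v for k, v in pairs}
```

Something like `ruff check --select C4 --fix .` applies these rewrites in bulk; the exact invocation used for this commit is not recorded on this page.
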
@@ -397,7 +397,7 @@ def main():
     # Preprocessing the datasets.
     # We need to tokenize input captions and transform the images.
     def tokenize_captions(examples):
-        captions = [caption for caption in examples[caption_column]]
+        captions = list(examples[caption_column])
         text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
         examples["input_ids"] = text_inputs.input_ids
         examples["attention_mask"] = text_inputs.attention_mask

@@ -250,7 +250,7 @@ def main():
     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = dataset["train"].features["labels"].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label

@@ -91,7 +91,7 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        data_files = dict()
+        data_files = {}
         if self.train_dir is not None:
             data_files["train"] = self.train_dir
         if self.validation_dir is not None:

@@ -104,7 +104,7 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        data_files = dict()
+        data_files = {}
         if self.train_dir is not None:
             data_files["train"] = self.train_dir
         if self.validation_dir is not None:

@@ -407,7 +407,7 @@ def main():
         )
     else:
         model = AutoModelForCausalLM.from_config(config)
-        n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
         logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch

@@ -457,14 +457,14 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

-    kwargs = dict(
-        finetuned_from=model_args.model_name_or_path,
-        tasks="multiple-choice",
-        dataset_tags="swag",
-        dataset_args="regular",
-        dataset="SWAG",
-        language="en",
-    )
+    kwargs = {
+        "finetuned_from": model_args.model_name_or_path,
+        "tasks": "multiple-choice",
+        "dataset_tags": "swag",
+        "dataset_args": "regular",
+        "dataset": "SWAG",
+        "language": "en",
+    }

     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)

@@ -430,7 +430,7 @@ def main():
             pixel_values.append(image)
             labels.append(target)

-        encoding = dict()
+        encoding = {}
         encoding["pixel_values"] = torch.stack(pixel_values)
         encoding["labels"] = torch.stack(labels)
@@ -444,7 +444,7 @@ def main():
             pixel_values.append(image)
             labels.append(target)

-        encoding = dict()
+        encoding = {}
         encoding["pixel_values"] = torch.stack(pixel_values)
         encoding["labels"] = torch.stack(labels)

@@ -441,7 +441,7 @@ def main():
             pixel_values.append(image)
             labels.append(target)

-        encoding = dict()
+        encoding = {}
         encoding["pixel_values"] = torch.stack(pixel_values)
         encoding["labels"] = torch.stack(labels)
@@ -455,7 +455,7 @@ def main():
             pixel_values.append(image)
             labels.append(target)

-        encoding = dict()
+        encoding = {}
         encoding["pixel_values"] = torch.stack(pixel_values)
         encoding["labels"] = torch.stack(labels)

@@ -349,7 +349,7 @@ def create_vocabulary_from_data(
         lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
     )

-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}

     # replace white space with delimiter token
     if word_delimiter_token is not None:

@@ -406,12 +406,12 @@ def main():
     ):
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
     elif data_args.task_name is None and not is_regression:

@@ -339,7 +339,7 @@ def main():
     ):
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             logger.info(
                 f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                 "Using it!"
@@ -348,7 +348,7 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
     elif args.task_name is None and not is_regression:

@@ -386,7 +386,7 @@ def main():

     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+        if sorted(model.config.label2id.keys()) == sorted(label_list):
             # Reorganize `label_list` to match the ordering of the model.
             if labels_are_int:
                 label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
@@ -397,8 +397,8 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
-                f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
+                f"model labels: {sorted(model.config.label2id.keys())}, dataset labels:"
+                f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
             )

     # Set the correspondences label/ID inside the model config

@@ -425,7 +425,7 @@ def main():

     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+        if sorted(model.config.label2id.keys()) == sorted(label_list):
             # Reorganize `label_list` to match the ordering of the model.
             if labels_are_int:
                 label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
@@ -436,8 +436,8 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
-                f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
+                f"model labels: {sorted(model.config.label2id.keys())}, dataset labels:"
+                f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
             )

     # Set the correspondences label/ID inside the model config

@@ -727,9 +727,9 @@ def main():
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
-            checkpoints = list(
-                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            checkpoints = [
+                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
+            ]

        logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -743,7 +743,7 @@ def main():
            print(f"Evaluation for checkpoint {prefix}")
            for patience in patience_list:
                result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
-                result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+                result = {k + "_{}".format(global_step): v for k, v in result.items()}
                results.update(result)

    return results

@@ -54,7 +54,7 @@ class BertAbs(BertAbsPreTrainedModel):
         load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
         if load_bert_pretrained_extractive:
             self.bert.model.load_state_dict(
-                dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
+                {n[11:]: p for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")},
                 strict=True,
             )

@@ -218,9 +218,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
     original_time = datetime.now() - before_time

     original_num_params = sum(p.numel() for p in model.parameters())
-    heads_to_prune = dict(
-        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
-    )
+    heads_to_prune = {
+        layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
+    }

     assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
     model.prune_heads(heads_to_prune)

@@ -194,9 +194,9 @@ def prune_heads(args, model, eval_dataloader, head_mask):
     original_time = datetime.now() - before_time

     original_num_params = sum(p.numel() for p in model.parameters())
-    heads_to_prune = dict(
-        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
-    )
+    heads_to_prune = {
+        layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
+    }

     for k, v in heads_to_prune.items():
         if isinstance(v, int):

@@ -29,7 +29,7 @@ def get_min_hash(tokens: List[str]) -> Optional[MinHash]:

 def get_tokens(code: str) -> Set[str]:
     """Tokenize a code snippet."""
-    return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0])
+    return {t for t in NON_ALPHA.split(code) if len(t.strip()) > 0}


 class DuplicationIndex:
@@ -243,7 +243,7 @@ def deduplicate_dataset(
     >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
     """
     duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
-    duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
+    duplicate_indices = {x["base_index"] for cluster in duplicate_clusters for x in cluster}
     extreme_dict = {}
     extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
     for extremes in extremes_clusters:

@@ -114,7 +114,7 @@ def char_token_ratio(example):

 def preprocess(example):
     """Chain all preprocessing steps into one function to not fill cache."""
-    results = dict()
+    results = {}
     results.update(get_hash(example))
     results.update(line_stats(example))
     results.update(alpha_stats(example))

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, HfArgumentParser

 def tokenize(example):
-    output = dict()
+    output = {}
     output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
     output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
     return output
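
One detail worth calling out from the hunk above that computes `n_params`: the expression builds a dict keyed on `p.data_ptr()` so that parameters sharing the same storage (e.g. tied embeddings) are counted only once, and the dict-comprehension rewrite preserves exactly that semantics. A toy sketch of why the keying matters; the `Tied` module below is hypothetical, not from the diff:

```python
import torch
import torch.nn as nn


class Tied(nn.Module):
    """Hypothetical module with two Parameter objects sharing one storage."""

    def __init__(self):
        super().__init__()
        shared = torch.zeros(4, 4)
        self.a = nn.Parameter(shared)
        self.b = nn.Parameter(shared)  # distinct Parameter, same underlying memory


model = Tied()

# A naive sum counts the shared storage twice (32 elements).
naive = sum(p.numel() for p in model.parameters())

# Keying by data_ptr() collapses aliased parameters to one entry (16 elements),
# which is the behavior the dict comprehension in the diff keeps.
deduped = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())

print(naive, deduped)  # 32 16
```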