Unverified commit 5e8c8eb5, authored by Aaron Gokaslan, committed by GitHub
Browse files

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
......@@ -5,7 +5,7 @@ target-version = ['py37']
[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["E501", "E741", "W605"]
select = ["E", "F", "I", "W"]
select = ["C", "E", "F", "I", "W"]
line-length = 119
# Ignore import violations in all `__init__.py` files.
......
......@@ -557,9 +557,9 @@ def stop_memory_tracing(
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
cumulative_memory = sorted(
list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True
cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
) # order by the total CPU + GPU memory increase
cumulative_memory = list(
cumulative_memory = [
MemoryState(
frame=frame,
cpu=Memory(cpu_mem_inc),
......@@ -567,7 +567,7 @@ def stop_memory_tracing(
cpu_gpu=Memory(cpu_gpu_mem_inc),
)
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
)
]
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
......
......@@ -324,7 +324,7 @@ class PretrainedConfig(PushToHubMixin):
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
)
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.id2label = {int(key): value for key, value in self.id2label.items()}
# Keys are always strings in JSON so convert ids to int here.
else:
self.num_labels = kwargs.pop("num_labels", 2)
......@@ -696,7 +696,7 @@ class PretrainedConfig(PushToHubMixin):
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
......
......@@ -367,13 +367,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
# keep for quick debug:
# from pprint import pprint; pprint(config)
kwargs = dict(
model=model,
model_parameters=model_parameters,
config_params=config,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
kwargs = {
"model": model,
"model_parameters": model_parameters,
"config_params": config,
"optimizer": optimizer,
"lr_scheduler": lr_scheduler,
}
deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
......
......@@ -188,7 +188,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
truncated_inputs = []
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
inputs = {k: v[i] for k, v in processed_features.items()}
# truncation
inputs_slice = self._truncate(
inputs,
......
......@@ -208,12 +208,12 @@ class DisjunctiveTrie:
"""
self.max_height = max([len(one) for one in nested_token_ids])
root = dict()
root = {}
for token_ids in nested_token_ids:
level = root
for tidx, token_id in enumerate(token_ids):
if token_id not in level:
level[token_id] = dict()
level[token_id] = {}
level = level[token_id]
......
......@@ -951,7 +951,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = [t for t in input_ids[k, self.begin_index :].tolist()]
seq = list(input_ids[k, self.begin_index :].tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
......
......@@ -115,7 +115,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
if is_valid_image(images):
if images.ndim == expected_ndims + 1:
# Batch of images
images = [image for image in images]
images = list(images)
elif images.ndim == expected_ndims:
# Single image
images = [images]
......
......@@ -365,7 +365,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
name="huggingface-tune",
type="offline",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
parallel_bandwidth=1,
budget=n_trials,
)
......@@ -402,7 +402,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
experiment = conn.experiments().create(
name="huggingface-tune",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
parallel_bandwidth=1,
observation_budget=n_trials,
project="huggingface",
......@@ -425,7 +425,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
values = [dict(name="objective", value=trainer.objective)]
values = [{"name": "objective", "value": trainer.objective}]
obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
experiment = conn.experiments(experiment.id).fetch()
......
......@@ -162,7 +162,7 @@ class KerasMetricCallback(Callback):
def _postprocess_predictions_or_labels(self, inputs):
if isinstance(inputs[0], dict):
outputs = dict()
outputs = {}
for key in inputs[0].keys():
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
# If it's a dict with only one key, just return the array
......
......@@ -677,7 +677,7 @@ class TrainingSummary:
_, eval_lines, eval_results = parse_keras_history(keras_history)
else:
eval_lines = []
eval_results = dict()
eval_results = {}
hyperparameters = extract_hyperparameters_from_keras(model)
return cls(
......@@ -706,7 +706,7 @@ def parse_keras_history(logs):
# This looks like a `History` object
if not hasattr(logs, "epoch"):
# This history looks empty, return empty results
return None, [], dict()
return None, [], {}
logs.history["epoch"] = logs.epoch
logs = logs.history
else:
......@@ -716,7 +716,7 @@ def parse_keras_history(logs):
lines = []
for i in range(len(logs["epoch"])):
epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
values = dict()
values = {}
for k, v in epoch_dict.items():
if k.startswith("val_"):
k = "validation_" + k[4:]
......@@ -797,7 +797,7 @@ def parse_log_history(log_history):
def extract_hyperparameters_from_keras(model):
import tensorflow as tf
hyperparameters = dict()
hyperparameters = {}
if hasattr(model, "optimizer") and model.optimizer is not None:
hyperparameters["optimizer"] = model.optimizer.get_config()
else:
......
......@@ -76,7 +76,7 @@ def rename_key_and_reshape_tensor(
def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
"""Checks if `key` or `(prefix,) + key` is in random_flax_state_dict"""
return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0
return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0
# layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
......@@ -122,10 +122,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
flax_state_dict = {}
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
)
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
)
# Need to change some parameters name to match Flax names
......@@ -179,10 +179,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
random_flax_state_dict = flatten_dict(flax_model.params)
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
)
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
)
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
......@@ -267,10 +267,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
pt_model_dict = pt_model.state_dict()
load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
)
load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
)
# keep track of unexpected & missing keys
......
......@@ -440,7 +440,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"""
# Load the index
state_sharded_dict = dict()
state_sharded_dict = {}
for shard_file in shard_files:
# load using msgpack utils
......@@ -708,19 +708,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
......
......@@ -258,7 +258,7 @@ def load_pytorch_state_dict_in_tf2_model(
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
tf_loaded_numel = 0
weight_value_tuples = []
all_pytorch_weights = set(list(pt_state_dict.keys()))
all_pytorch_weights = set(pt_state_dict.keys())
missing_keys = []
for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name
......@@ -425,7 +425,7 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
)
tf_weights_map[pt_name] = (tf_weight, transpose)
all_tf_weights = set(list(tf_weights_map.keys()))
all_tf_weights = set(tf_weights_map.keys())
loaded_pt_weights_data_ptr = {}
missing_keys_pt = []
for pt_weight_name, pt_weight in current_pt_params_dict.items():
......
......@@ -584,7 +584,7 @@ def input_processing(func, config, **kwargs):
if "kwargs" in output:
del output["kwargs"]
cast_output = dict()
cast_output = {}
for key, val in output.items():
if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
cast_output[key] = tf.cast(val, tf.int32)
......@@ -737,7 +737,7 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
# the weight, we have to get rid of the first prefix of the name of the layer.
model_keys = set()
model_layer_map = dict()
model_layer_map = {}
for i, k in enumerate(model.weights):
if "model." in k.name or len(k.name.split("/")) == 1:
layer_name = k.name
......@@ -901,10 +901,10 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
)
# Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
# Find the unexpected layers from the high level list of layers
unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
saved_weight_names_set = set()
symbolic_weights_names = set()
weight_value_tuples = []
......@@ -1349,7 +1349,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
if collate_fn_args is None:
collate_fn_args = dict()
collate_fn_args = {}
if not isinstance(dataset, datasets.Dataset):
raise TypeError("Dataset argument should be a datasets.Dataset!")
......@@ -1471,7 +1471,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
elif "mc_labels" in arg_names:
return {"labels": "logits", "mc_labels": "mc_logits"}
else:
return dict()
return {}
def train_step(self, data):
"""
......@@ -2613,19 +2613,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
......
......@@ -1271,7 +1271,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
......@@ -2304,19 +2304,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
......@@ -2474,7 +2474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
loaded_state_dict_keys = list(state_dict.keys())
if low_cpu_mem_usage or use_keep_in_fp32_modules:
state_dict = None
......@@ -3046,12 +3046,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
module_keys = {".".join(key.split(".")[:-1]) for key in names}
# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union(
set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
{".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
)
retrieved_modules = []
......
......@@ -555,7 +555,7 @@ class FlaxBeitEncoder(nn.Module):
)
# stochastic depth decay rule
drop_path_rates = [x for x in np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)]
drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
self.layer = FlaxBeitLayerCollection(
self.config,
window_size=self.window_size,
......
......@@ -318,7 +318,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
split_tokens = []
words = re.findall(r"\S+\n?", text)
for token in words:
split_tokens.extend([t for t in self.bpe(token).split(" ")])
split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
def normalizeTweet(self, tweet):
......@@ -726,7 +726,7 @@ class TweetTokenizer:
words = WORD_RE.findall(safe_text)
# Possibly alter the case, but avoid changing emoticons like :D into :d:
if not self.preserve_case:
words = list(map((lambda x: x if EMOTICON_RE.search(x) else x.lower()), words))
words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
return words
......
......@@ -202,7 +202,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
......
......@@ -132,8 +132,8 @@ class BioGptTokenizer(PreTrainedTokenizer):
self.lang = "en"
self.sm = sacremoses
# cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict()
self.cache_moses_detokenizer = dict()
self.cache_moses_tokenizer = {}
self.cache_moses_detokenizer = {}
""" Initialisation"""
with open(vocab_file, encoding="utf-8") as vocab_handle:
......@@ -221,7 +221,7 @@ class BioGptTokenizer(PreTrainedTokenizer):
split_tokens = []
for token in text:
if token:
split_tokens.extend([t for t in self.bpe(token).split(" ")])
split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment