"vscode:/vscode.git/clone" did not exist on "aa61f940598c8e979e8c5632aa61d7d23db7d5c0"
Unverified commit 5e8c8eb5, authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
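For reference, every change below is a rewrite enforced by ruff's flake8-comprehensions rules. A minimal before/after sketch of the recurring patterns, using hypothetical variable names and indicative rule codes (not taken from any transformers file):

# C417: replace map(lambda ...) with a comprehension
handles = ["a", "bb", "ccc"]
lengths = list(map(lambda x: len(x), handles))   # before
lengths = [len(x) for x in handles]              # after

# C405: rewrite set([...]) as a set literal
langs = set(["zh", "th", "ja"])                  # before
langs = {"zh", "th", "ja"}                       # after

# C402: rewrite dict(<generator of pairs>) as a dict comprehension
pairs = [("a", 1), ("b", 2)]
mapping = dict((k, v) for k, v in pairs)         # before
mapping = {k: v for k, v in pairs}               # after

# C408: rewrite empty dict()/list() calls as literals
cache = dict()                                   # before
cache = {}                                       # after

# C416: drop a comprehension that only re-lists an iterable
tokens = [t for t in "a b c".split(" ")]         # before
tokens = list("a b c".split(" "))                # after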
@@ -60,7 +60,7 @@ class Tracker:
for name, m in self.module.named_modules():
self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name)))
self.module(x)
-list(map(lambda x: x.remove(), self.handles))
+[x.remove() for x in self.handles]
return self
@property
......
@@ -53,7 +53,7 @@ class Tracker:
for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x)
-list(map(lambda x: x.remove(), self.handles))
+[x.remove() for x in self.handles]
return self
@property
......
@@ -247,7 +247,7 @@ class TFRegNetStage(tf.keras.layers.Layer):
class TFRegNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: RegNetConfig, **kwargs):
super().__init__(**kwargs)
-self.stages = list()
+self.stages = []
# based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
self.stages.append(
TFRegNetStage(
......
@@ -219,7 +219,7 @@ class RemBertTokenizer(PreTrainedTokenizer):
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
-return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
......
@@ -191,7 +191,7 @@ class RemBertTokenizerFast(PreTrainedTokenizerFast):
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
-return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
......
@@ -51,7 +51,7 @@ class Tracker:
for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x)
-list(map(lambda x: x.remove(), self.handles))
+[x.remove() for x in self.handles]
return self
@property
......
@@ -1240,7 +1240,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T) # batch_size * hidden_dim
sim_matrix_target = torch.matmul(labels_pooled_output_norm, attack_pooled_output_norm.T)
-batch_labels = torch.tensor([i for i in range(batch_size)], device=device)
+batch_labels = torch.tensor(list(range(batch_size)), device=device)
contrastive_loss = (
loss_fct(100 * sim_matrix.view(batch_size, -1), batch_labels.view(-1))
+ loss_fct(100 * sim_matrix_target.view(batch_size, -1), batch_labels.view(-1))
......
@@ -95,12 +95,10 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
model = Speech2TextForConditionalGeneration(config)
missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
-if len(missing) > 0 and not set(missing) <= set(
-[
+if len(missing) > 0 and not set(missing) <= {
"encoder.embed_positions.weights",
"decoder.embed_positions.weights",
-]
-):
+}:
raise ValueError(
"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
f" but all the following weights are missing {missing}"
......
@@ -213,7 +213,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
split_tokens = []
for token in text:
if token:
-split_tokens.extend([t for t in self.bpe(token).split(" ")])
+split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
......
@@ -1259,7 +1259,7 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
self.out_feature_channels[stage] = num_features[i]
# Add layer norms to hidden states of out_features
-hidden_states_norms = dict()
+hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
......
@@ -1688,7 +1688,7 @@ class TapasTokenizer(PreTrainedTokenizer):
for col_index in range(num_columns):
for row_index in range(num_rows):
-indices = [index for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index)]
+indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index))
num_indices = len(indices)
if num_indices > 1:
for index in indices:
......
@@ -1453,16 +1453,16 @@ class TapexTokenizer(PreTrainedTokenizer):
truncated_unrelated_indices = []
related_indices = []
if answer is None or len(answer) == 0:
-answer_set = set([])
+answer_set = set()
else:
-answer_set = set([ans_ex.lower() for ans_ex in answer])
+answer_set = {ans_ex.lower() for ans_ex in answer}
# add question key words into answer set
if question is not None:
answer_set.update(question.split())
question_set = set(question.strip("?!.,").split(" "))
row_max_len = len(table_content["rows"])
for _row_idx, row in enumerate(table_content["rows"]):
-lower_row = set([str(cell).lower() for cell in row])
+lower_row = {str(cell).lower() for cell in row}
if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0:
truncated_unrelated_indices.append(_row_idx)
else:
......
@@ -55,7 +55,7 @@ class Tracker:
for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x)
-list(map(lambda x: x.remove(), self.handles))
+[x.remove() for x in self.handles]
return self
@property
......
@@ -171,7 +171,7 @@ class ViltEmbeddings(nn.Module):
non_valid_nums = [v.size(0) for v in non_valid_row_idx]
pad_nums = [max_image_length - v for v in valid_nums]
-select = list()
+select = []
for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
if p <= 0:
valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
......
@@ -648,7 +648,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
-added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
+added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
......
@@ -615,7 +615,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
-added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
+added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
......
@@ -157,12 +157,10 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
model = WhisperForConditionalGeneration(config)
missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
-if len(missing) > 0 and not set(missing) <= set(
-[
+if len(missing) > 0 and not set(missing) <= {
"encoder.embed_positions.weights",
"decoder.embed_positions.weights",
-]
-):
+}:
raise ValueError(
"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
f" but all the following weights are missing {missing}"
......
@@ -189,8 +189,7 @@ class EnglishNumberNormalizer:
}
self.specials = {"and", "double", "triple", "point"}
-self.words = set(
-[
+self.words = {
key
for mapping in [
self.zeros,
@@ -206,8 +205,7 @@ class EnglishNumberNormalizer:
self.specials,
]
for key in mapping
-]
-)
+}
self.literal_words = {"one", "ones"}
def process_words(self, words: List[str]) -> Iterator[str]:
......
@@ -43,10 +43,10 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
two_levels_state_dict["transformer." + k] = v
config = chkpt["params"]
-config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
+config = {n: v for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))}
vocab = chkpt["dico_word2id"]
-vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
+vocab = {s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""): i for s, i in vocab.items()}
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
......
@@ -638,10 +638,10 @@ class XLMTokenizer(PreTrainedTokenizer):
self.sm = sacremoses
# cache of sm.MosesPunctNormalizer instance
-self.cache_moses_punct_normalizer = dict()
+self.cache_moses_punct_normalizer = {}
# cache of sm.MosesTokenizer instance
-self.cache_moses_tokenizer = dict()
-self.lang_with_custom_tokenizer = set(["zh", "th", "ja"])
+self.cache_moses_tokenizer = {}
+self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
# True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
self.lang2id = lang2id
@@ -851,7 +851,7 @@ class XLMTokenizer(PreTrainedTokenizer):
split_tokens = []
for token in text:
if token:
-split_tokens.extend([t for t in self.bpe(token).split(" ")])
+split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
......