Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan's avatar Aaron Gokaslan Committed by GitHub
Browse files

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
...@@ -60,7 +60,7 @@ class Tracker: ...@@ -60,7 +60,7 @@ class Tracker:
for name, m in self.module.named_modules(): for name, m in self.module.named_modules():
self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name))) self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name)))
self.module(x) self.module(x)
list(map(lambda x: x.remove(), self.handles)) [x.remove() for x in self.handles]
return self return self
@property @property
......
...@@ -53,7 +53,7 @@ class Tracker: ...@@ -53,7 +53,7 @@ class Tracker:
for m in self.module.modules(): for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook)) self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x) self.module(x)
list(map(lambda x: x.remove(), self.handles)) [x.remove() for x in self.handles]
return self return self
@property @property
......
...@@ -247,7 +247,7 @@ class TFRegNetStage(tf.keras.layers.Layer): ...@@ -247,7 +247,7 @@ class TFRegNetStage(tf.keras.layers.Layer):
class TFRegNetEncoder(tf.keras.layers.Layer): class TFRegNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: RegNetConfig, **kwargs): def __init__(self, config: RegNetConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.stages = list() self.stages = []
# based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
self.stages.append( self.stages.append(
TFRegNetStage( TFRegNetStage(
......
...@@ -219,7 +219,7 @@ class RemBertTokenizer(PreTrainedTokenizer): ...@@ -219,7 +219,7 @@ class RemBertTokenizer(PreTrainedTokenizer):
"You should not supply a second sequence if the provided sequence of " "You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model." "ids is already formatted with special tokens for the model."
) )
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is not None: if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
......
...@@ -191,7 +191,7 @@ class RemBertTokenizerFast(PreTrainedTokenizerFast): ...@@ -191,7 +191,7 @@ class RemBertTokenizerFast(PreTrainedTokenizerFast):
"You should not supply a second sequence if the provided sequence of " "You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model." "ids is already formatted with special tokens for the model."
) )
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is not None: if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
......
...@@ -51,7 +51,7 @@ class Tracker: ...@@ -51,7 +51,7 @@ class Tracker:
for m in self.module.modules(): for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook)) self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x) self.module(x)
list(map(lambda x: x.remove(), self.handles)) [x.remove() for x in self.handles]
return self return self
@property @property
......
...@@ -1240,7 +1240,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel): ...@@ -1240,7 +1240,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T) # batch_size * hidden_dim sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T) # batch_size * hidden_dim
sim_matrix_target = torch.matmul(labels_pooled_output_norm, attack_pooled_output_norm.T) sim_matrix_target = torch.matmul(labels_pooled_output_norm, attack_pooled_output_norm.T)
batch_labels = torch.tensor([i for i in range(batch_size)], device=device) batch_labels = torch.tensor(list(range(batch_size)), device=device)
contrastive_loss = ( contrastive_loss = (
loss_fct(100 * sim_matrix.view(batch_size, -1), batch_labels.view(-1)) loss_fct(100 * sim_matrix.view(batch_size, -1), batch_labels.view(-1))
+ loss_fct(100 * sim_matrix_target.view(batch_size, -1), batch_labels.view(-1)) + loss_fct(100 * sim_matrix_target.view(batch_size, -1), batch_labels.view(-1))
......
...@@ -95,12 +95,10 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_ ...@@ -95,12 +95,10 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
model = Speech2TextForConditionalGeneration(config) model = Speech2TextForConditionalGeneration(config)
missing, unexpected = model.model.load_state_dict(state_dict, strict=False) missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
if len(missing) > 0 and not set(missing) <= set( if len(missing) > 0 and not set(missing) <= {
[ "encoder.embed_positions.weights",
"encoder.embed_positions.weights", "decoder.embed_positions.weights",
"decoder.embed_positions.weights", }:
]
):
raise ValueError( raise ValueError(
"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing," "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
f" but all the following weights are missing {missing}" f" but all the following weights are missing {missing}"
......
...@@ -213,7 +213,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer): ...@@ -213,7 +213,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
split_tokens = [] split_tokens = []
for token in text: for token in text:
if token: if token:
split_tokens.extend([t for t in self.bpe(token).split(" ")]) split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens return split_tokens
......
...@@ -1259,7 +1259,7 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin): ...@@ -1259,7 +1259,7 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
self.out_feature_channels[stage] = num_features[i] self.out_feature_channels[stage] = num_features[i]
# Add layer norms to hidden states of out_features # Add layer norms to hidden states of out_features
hidden_states_norms = dict() hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels): for stage, num_channels in zip(self.out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels) hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
......
...@@ -1688,7 +1688,7 @@ class TapasTokenizer(PreTrainedTokenizer): ...@@ -1688,7 +1688,7 @@ class TapasTokenizer(PreTrainedTokenizer):
for col_index in range(num_columns): for col_index in range(num_columns):
for row_index in range(num_rows): for row_index in range(num_rows):
indices = [index for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index)] indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index))
num_indices = len(indices) num_indices = len(indices)
if num_indices > 1: if num_indices > 1:
for index in indices: for index in indices:
......
...@@ -1453,16 +1453,16 @@ class TapexTokenizer(PreTrainedTokenizer): ...@@ -1453,16 +1453,16 @@ class TapexTokenizer(PreTrainedTokenizer):
truncated_unrelated_indices = [] truncated_unrelated_indices = []
related_indices = [] related_indices = []
if answer is None or len(answer) == 0: if answer is None or len(answer) == 0:
answer_set = set([]) answer_set = set()
else: else:
answer_set = set([ans_ex.lower() for ans_ex in answer]) answer_set = {ans_ex.lower() for ans_ex in answer}
# add question key words into answer set # add question key words into answer set
if question is not None: if question is not None:
answer_set.update(question.split()) answer_set.update(question.split())
question_set = set(question.strip("?!.,").split(" ")) question_set = set(question.strip("?!.,").split(" "))
row_max_len = len(table_content["rows"]) row_max_len = len(table_content["rows"])
for _row_idx, row in enumerate(table_content["rows"]): for _row_idx, row in enumerate(table_content["rows"]):
lower_row = set([str(cell).lower() for cell in row]) lower_row = {str(cell).lower() for cell in row}
if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0: if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0:
truncated_unrelated_indices.append(_row_idx) truncated_unrelated_indices.append(_row_idx)
else: else:
......
...@@ -55,7 +55,7 @@ class Tracker: ...@@ -55,7 +55,7 @@ class Tracker:
for m in self.module.modules(): for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook)) self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x) self.module(x)
list(map(lambda x: x.remove(), self.handles)) [x.remove() for x in self.handles]
return self return self
@property @property
......
...@@ -171,7 +171,7 @@ class ViltEmbeddings(nn.Module): ...@@ -171,7 +171,7 @@ class ViltEmbeddings(nn.Module):
non_valid_nums = [v.size(0) for v in non_valid_row_idx] non_valid_nums = [v.size(0) for v in non_valid_row_idx]
pad_nums = [max_image_length - v for v in valid_nums] pad_nums = [max_image_length - v for v in valid_nums]
select = list() select = []
for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)): for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
if p <= 0: if p <= 0:
valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length) valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
......
...@@ -648,7 +648,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): ...@@ -648,7 +648,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
if self.verbose: if self.verbose:
logger.info(f"Adding {token} to the vocabulary") logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder) self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder) self.added_tokens_decoder.update(added_tok_decoder)
......
...@@ -615,7 +615,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): ...@@ -615,7 +615,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
if self.verbose: if self.verbose:
logger.info(f"Adding {token} to the vocabulary") logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) added_tok_encoder = {tok: len(self) + i for i, tok in enumerate(tokens_to_add)}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder) self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder) self.added_tokens_decoder.update(added_tok_decoder)
......
...@@ -157,12 +157,10 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path): ...@@ -157,12 +157,10 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
model = WhisperForConditionalGeneration(config) model = WhisperForConditionalGeneration(config)
missing, unexpected = model.model.load_state_dict(state_dict, strict=False) missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
if len(missing) > 0 and not set(missing) <= set( if len(missing) > 0 and not set(missing) <= {
[ "encoder.embed_positions.weights",
"encoder.embed_positions.weights", "decoder.embed_positions.weights",
"decoder.embed_positions.weights", }:
]
):
raise ValueError( raise ValueError(
"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing," "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
f" but all the following weights are missing {missing}" f" but all the following weights are missing {missing}"
......
...@@ -189,25 +189,23 @@ class EnglishNumberNormalizer: ...@@ -189,25 +189,23 @@ class EnglishNumberNormalizer:
} }
self.specials = {"and", "double", "triple", "point"} self.specials = {"and", "double", "triple", "point"}
self.words = set( self.words = {
[ key
key for mapping in [
for mapping in [ self.zeros,
self.zeros, self.ones,
self.ones, self.ones_suffixed,
self.ones_suffixed, self.tens,
self.tens, self.tens_suffixed,
self.tens_suffixed, self.multipliers,
self.multipliers, self.multipliers_suffixed,
self.multipliers_suffixed, self.preceding_prefixers,
self.preceding_prefixers, self.following_prefixers,
self.following_prefixers, self.suffixers,
self.suffixers, self.specials,
self.specials,
]
for key in mapping
] ]
) for key in mapping
}
self.literal_words = {"one", "ones"} self.literal_words = {"one", "ones"}
def process_words(self, words: List[str]) -> Iterator[str]: def process_words(self, words: List[str]) -> Iterator[str]:
......
...@@ -43,10 +43,10 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p ...@@ -43,10 +43,10 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
two_levels_state_dict["transformer." + k] = v two_levels_state_dict["transformer." + k] = v
config = chkpt["params"] config = chkpt["params"]
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) config = {n: v for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))}
vocab = chkpt["dico_word2id"] vocab = chkpt["dico_word2id"]
vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items()) vocab = {s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""): i for s, i in vocab.items()}
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
......
...@@ -638,10 +638,10 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -638,10 +638,10 @@ class XLMTokenizer(PreTrainedTokenizer):
self.sm = sacremoses self.sm = sacremoses
# cache of sm.MosesPunctNormalizer instance # cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict() self.cache_moses_punct_normalizer = {}
# cache of sm.MosesTokenizer instance # cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict() self.cache_moses_tokenizer = {}
self.lang_with_custom_tokenizer = set(["zh", "th", "ja"]) self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
# True for current supported model (v1.2.0), False for XLM-17 & 100 # True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
self.lang2id = lang2id self.lang2id = lang2id
...@@ -851,7 +851,7 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -851,7 +851,7 @@ class XLMTokenizer(PreTrainedTokenizer):
split_tokens = [] split_tokens = []
for token in text: for token in text:
if token: if token:
split_tokens.extend([t for t in self.bpe(token).split(" ")]) split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens return split_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment