Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
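
The commit applies ruff's flake8-comprehensions autofixes across the repository. As a rough illustration of the rewrite patterns seen in the hunks below (a sketch with made-up variable names, not lines from this patch; rule codes cited from memory of the flake8-comprehensions docs):

    # dict()/set() calls with keyword arguments or list literals -> dict/set literals (C408, C405)
    opts = dict(target_lang="en", separate_vocabs=False)    # before
    opts = {"target_lang": "en", "separate_vocabs": False}  # after
    attn_types = set(["lsh", "local"])                       # before
    attn_types = {"lsh", "local"}                            # after

    # dict() wrapped around a generator expression -> dict comprehension (C402)
    inputs = dict((k, v) for k, v in opts.items())           # before
    inputs = {k: v for k, v in opts.items()}                 # after

    # comprehension that only copies an iterable -> list() (C416)
    runs = [x for x in range(5)]                             # before
    runs = list(range(5))                                    # after
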
@@ -42,7 +42,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
# Add special tokens to the token vocabulary for downstream tasks
entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
-tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
+tokenizer.add_special_tokens({"additional_special_tokens": [entity_token_1, entity_token_2]})
config.vocab_size += 2
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
...
@@ -1529,7 +1529,7 @@ class LukeTokenizer(PreTrainedTokenizer):
batch_outputs = {}
for i in range(batch_size):
-    inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
+    inputs = {k: v[i] for k, v in encoded_inputs.items()}
    outputs = self._pad(
        inputs,
        max_length=max_length,
...
@@ -185,12 +185,12 @@ def convert_hf_name_to_opus_name(hf_model_name):
def get_system_metadata(repo_root):
    import git
-    return dict(
-        helsinki_git_sha=git.Repo(path=repo_root, search_parent_directories=True).head.object.hexsha,
-        transformers_git_sha=git.Repo(path=".", search_parent_directories=True).head.object.hexsha,
-        port_machine=socket.gethostname(),
-        port_time=time.strftime("%Y-%m-%d-%H:%M"),
-    )
+    return {
+        "helsinki_git_sha": git.Repo(path=repo_root, search_parent_directories=True).head.object.hexsha,
+        "transformers_git_sha": git.Repo(path=".", search_parent_directories=True).head.object.hexsha,
+        "port_machine": socket.gethostname(),
+        "port_time": time.strftime("%Y-%m-%d-%H:%M"),
+    }
# docstyle-ignore
@@ -366,7 +366,7 @@ def _parse_readme(lns):
def save_tokenizer_config(dest_dir: Path, separate_vocabs=False):
    dname = dest_dir.name.split("-")
-    dct = dict(target_lang=dname[-1], source_lang="-".join(dname[:-1]), separate_vocabs=separate_vocabs)
+    dct = {"target_lang": dname[-1], "source_lang": "-".join(dname[:-1]), "separate_vocabs": separate_vocabs}
    save_json(dct, dest_dir / "tokenizer_config.json")
...
@@ -76,7 +76,7 @@ class TrackedStateDict:
Returns:
    List[str]: List of keys not yet updated
"""
-return set(list(self.to_track.keys())) - self._seen
+return set(self.to_track.keys()) - self._seen
def copy(self) -> Dict:
    # proxy the call to the internal dictionary
...
@@ -119,7 +119,7 @@ def binary_mask_to_rle(mask):
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
-return [x for x in runs]
+return list(runs)
# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
...
@@ -72,7 +72,7 @@ class TrackedStateDict:
Returns:
    List[str]: List of keys not yet updated
"""
-return set(list(self.to_track.keys())) - self._seen
+return set(self.to_track.keys()) - self._seen
def copy(self) -> Dict:
    # proxy the call to the internal dictionary
@@ -120,43 +120,43 @@ class OriginalMaskFormerConfigToOursConverter:
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
no_object_weight=mask_former.NO_OBJECT_WEIGHT,
num_queries=mask_former.NUM_OBJECT_QUERIES,
-backbone_config=dict(
-    pretrain_img_size=swin.PRETRAIN_IMG_SIZE,
-    image_size=swin.PRETRAIN_IMG_SIZE,
-    in_channels=3,
-    patch_size=swin.PATCH_SIZE,
-    embed_dim=swin.EMBED_DIM,
-    depths=swin.DEPTHS,
-    num_heads=swin.NUM_HEADS,
-    window_size=swin.WINDOW_SIZE,
-    drop_path_rate=swin.DROP_PATH_RATE,
-    model_type="swin",
-),
+backbone_config={
+    "pretrain_img_size": swin.PRETRAIN_IMG_SIZE,
+    "image_size": swin.PRETRAIN_IMG_SIZE,
+    "in_channels": 3,
+    "patch_size": swin.PATCH_SIZE,
+    "embed_dim": swin.EMBED_DIM,
+    "depths": swin.DEPTHS,
+    "num_heads": swin.NUM_HEADS,
+    "window_size": swin.WINDOW_SIZE,
+    "drop_path_rate": swin.DROP_PATH_RATE,
+    "model_type": "swin",
+},
dice_weight=mask_former.DICE_WEIGHT,
ce_weight=1.0,
mask_weight=mask_former.MASK_WEIGHT,
-decoder_config=dict(
-    model_type="detr",
-    max_position_embeddings=1024,
-    encoder_layers=6,
-    encoder_ffn_dim=2048,
-    encoder_attention_heads=8,
-    decoder_layers=mask_former.DEC_LAYERS,
-    decoder_ffn_dim=mask_former.DIM_FEEDFORWARD,
-    decoder_attention_heads=mask_former.NHEADS,
-    encoder_layerdrop=0.0,
-    decoder_layerdrop=0.0,
-    d_model=mask_former.HIDDEN_DIM,
-    dropout=mask_former.DROPOUT,
-    attention_dropout=0.0,
-    activation_dropout=0.0,
-    init_std=0.02,
-    init_xavier_std=1.0,
-    scale_embedding=False,
-    auxiliary_loss=False,
-    dilation=False,
+decoder_config={
+    "model_type": "detr",
+    "max_position_embeddings": 1024,
+    "encoder_layers": 6,
+    "encoder_ffn_dim": 2048,
+    "encoder_attention_heads": 8,
+    "decoder_layers": mask_former.DEC_LAYERS,
+    "decoder_ffn_dim": mask_former.DIM_FEEDFORWARD,
+    "decoder_attention_heads": mask_former.NHEADS,
+    "encoder_layerdrop": 0.0,
+    "decoder_layerdrop": 0.0,
+    "d_model": mask_former.HIDDEN_DIM,
+    "dropout": mask_former.DROPOUT,
+    "attention_dropout": 0.0,
+    "activation_dropout": 0.0,
+    "init_std": 0.02,
+    "init_xavier_std": 1.0,
+    "scale_embedding": False,
+    "auxiliary_loss": False,
+    "dilation": False,
    # default pretrained config values
-),
+},
id2label=id2label,
label2id=label2id,
)
...
@@ -123,7 +123,7 @@ def binary_mask_to_rle(mask):
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
-return [x for x in runs]
+return list(runs)
# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
...
@@ -46,7 +46,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
# Add special tokens to the token vocabulary for downstream tasks
entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
-tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
+tokenizer.add_special_tokens({"additional_special_tokens": [entity_token_1, entity_token_2]})
config.vocab_size += 2
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
...
@@ -1328,7 +1328,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
batch_outputs = {}
for i in range(batch_size):
-    inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
+    inputs = {k: v[i] for k, v in encoded_inputs.items()}
    outputs = self._pad(
        inputs,
        max_length=max_length,
...
@@ -877,7 +877,7 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
self.out_feature_channels[stage] = num_features[i]
# Add layer norms to hidden states of out_features
-hidden_states_norms = dict()
+hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels):
    hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
...
@@ -82,7 +82,7 @@ class TrackedStateDict:
Returns:
    List[str]: List of keys not yet updated
"""
-return set(list(self.to_track.keys())) - self._seen
+return set(self.to_track.keys()) - self._seen
def copy(self) -> Dict:
    # proxy the call to the internal dictionary
...
@@ -120,7 +120,7 @@ def binary_mask_to_rle(mask):
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
-return [x for x in runs]
+return list(runs)
# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
...
@@ -342,12 +342,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
    # Using BERT's BasicTokenizer
    text = self.nlp.tokenize(text)
    for token in text:
-        split_tokens.extend([t for t in self.bpe(token).split(" ")])
+        split_tokens.extend(list(self.bpe(token).split(" ")))
else:
    # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
    text = self.nlp(text_standardize(self.fix_text(text)))
    for token in text:
-        split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")])
+        split_tokens.extend(list(self.bpe(token.text.lower()).split(" ")))
return split_tokens
def _convert_token_to_id(self, token):
...
@@ -37,42 +37,42 @@ from transformers import (
CONFIGS = {
-    "vit_b32": dict(
-        embed_dim=512,
-        image_resolution=768,
-        context_length=16,
-        vocab_size=49408,
-        vision_layers=12,
-        vision_width=768,
-        vision_patch_size=32,
-        transformer_width=512,
-        transformer_heads=8,
-        transformer_layers=12,
-    ),
-    "vit_b16": dict(
-        embed_dim=512,
-        image_resolution=768,
-        context_length=16,
-        vocab_size=49408,
-        vision_layers=12,
-        vision_width=768,
-        vision_patch_size=16,
-        transformer_width=512,
-        transformer_heads=8,
-        transformer_layers=12,
-    ),
-    "vit_l14": dict(
-        embed_dim=768,
-        image_resolution=840,
-        context_length=16,
-        vocab_size=49408,
-        vision_layers=24,
-        vision_width=1024,
-        vision_patch_size=14,
-        transformer_width=768,
-        transformer_heads=12,
-        transformer_layers=12,
-    ),
+    "vit_b32": {
+        "embed_dim": 512,
+        "image_resolution": 768,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 12,
+        "vision_width": 768,
+        "vision_patch_size": 32,
+        "transformer_width": 512,
+        "transformer_heads": 8,
+        "transformer_layers": 12,
+    },
+    "vit_b16": {
+        "embed_dim": 512,
+        "image_resolution": 768,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 12,
+        "vision_width": 768,
+        "vision_patch_size": 16,
+        "transformer_width": 512,
+        "transformer_heads": 8,
+        "transformer_layers": 12,
+    },
+    "vit_l14": {
+        "embed_dim": 768,
+        "image_resolution": 840,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 24,
+        "vision_width": 1024,
+        "vision_patch_size": 14,
+        "transformer_width": 768,
+        "transformer_heads": 12,
+        "transformer_layers": 12,
+    },
}
...
@@ -283,7 +283,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architec
params = checkpoint
# turn into initial state dict
-state_dict = dict()
+state_dict = {}
for scope_name, parameters in hk.data_structures.to_mutable_dict(params).items():
    for param_name, param in parameters.items():
        state_dict[scope_name + "/" + param_name] = param
@@ -398,7 +398,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architec
elif architecture == "multimodal_autoencoding":
    images = torch.randn((1, 16, 3, 224, 224))
    audio = torch.randn((1, 30720, 1))
-    inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
+    inputs = {"image": images, "audio": audio, "label": torch.zeros((images.shape[0], 700))}
# forward pass
if architecture == "multimodal_autoencoding":
...
@@ -957,9 +957,10 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
text_preprocessor = PerceiverTextPreprocessor(config)
-trainable_position_encoding_kwargs_decoder = dict(
-    num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
-)
+trainable_position_encoding_kwargs_decoder = {
+    "num_channels": text_preprocessor.num_channels,
+    "index_dims": config.max_position_embeddings,
+}
self.perceiver = PerceiverModel(
    config,
@@ -1089,7 +1090,7 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
def __init__(self, config):
    super().__init__(config)
-    trainable_position_encoding_kwargs_decoder = dict(num_channels=config.d_latents, index_dims=1)
+    trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
    self.num_labels = config.num_labels
    self.perceiver = PerceiverModel(
@@ -1214,8 +1215,8 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
def __init__(self, config):
    super().__init__(config)
-    trainable_position_encoding_kwargs_preprocessor = dict(num_channels=256, index_dims=config.image_size**2)
-    trainable_position_encoding_kwargs_decoder = dict(num_channels=config.d_latents, index_dims=1)
+    trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size**2}
+    trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
    self.num_labels = config.num_labels
    self.perceiver = PerceiverModel(
@@ -1357,10 +1358,13 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
def __init__(self, config):
    super().__init__(config)
-    fourier_position_encoding_kwargs_preprocessor = dict(
-        concat_pos=True, max_resolution=(224, 224), num_bands=64, sine_only=False
-    )
-    trainable_position_encoding_kwargs_decoder = dict(num_channels=config.d_latents, index_dims=1)
+    fourier_position_encoding_kwargs_preprocessor = {
+        "concat_pos": True,
+        "max_resolution": (224, 224),
+        "num_bands": 64,
+        "sine_only": False,
+    }
+    trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
    self.num_labels = config.num_labels
    self.perceiver = PerceiverModel(
@@ -1497,10 +1501,13 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
def __init__(self, config):
    super().__init__(config)
-    fourier_position_encoding_kwargs_preprocessor = dict(
-        concat_pos=True, max_resolution=(56, 56), num_bands=64, sine_only=False
-    )
-    trainable_position_encoding_kwargs_decoder = dict(num_channels=config.d_latents, index_dims=1)
+    fourier_position_encoding_kwargs_preprocessor = {
+        "concat_pos": True,
+        "max_resolution": (56, 56),
+        "num_bands": 64,
+        "sine_only": False,
+    }
+    trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
    self.num_labels = config.num_labels
    self.perceiver = PerceiverModel(
@@ -1638,15 +1645,18 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
def __init__(self, config):
    super().__init__(config)
-    fourier_position_encoding_kwargs_preprocessor = dict(
-        num_bands=64,
-        max_resolution=config.train_size,
-        sine_only=False,
-        concat_pos=True,
-    )
-    fourier_position_encoding_kwargs_decoder = dict(
-        concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
-    )
+    fourier_position_encoding_kwargs_preprocessor = {
+        "num_bands": 64,
+        "max_resolution": config.train_size,
+        "sine_only": False,
+        "concat_pos": True,
+    }
+    fourier_position_encoding_kwargs_decoder = {
+        "concat_pos": True,
+        "max_resolution": config.train_size,
+        "num_bands": 64,
+        "sine_only": False,
+    }
    image_preprocessor = PerceiverImagePreprocessor(
        config,
@@ -1788,24 +1798,24 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
"audio": PerceiverAudioPreprocessor(
    config,
    position_encoding_type="fourier",
-    fourier_position_encoding_kwargs=dict(
-        num_bands=192,
-        max_resolution=(n_audio_samples,),
-        sine_only=False,
-        concat_pos=True,
-    ),
+    fourier_position_encoding_kwargs={
+        "num_bands": 192,
+        "max_resolution": (n_audio_samples,),
+        "sine_only": False,
+        "concat_pos": True,
+    },
    prep_type="patches",
    samples_per_patch=config.samples_per_patch,
),
"image": PerceiverImagePreprocessor(
    config,
    position_encoding_type="fourier",
-    fourier_position_encoding_kwargs=dict(
-        num_bands=32,
-        max_resolution=(config.num_frames, config.image_size, config.image_size),
-        sine_only=False,
-        concat_pos=True,
-    ),
+    fourier_position_encoding_kwargs={
+        "num_bands": 32,
+        "max_resolution": (config.num_frames, config.image_size, config.image_size),
+        "sine_only": False,
+        "concat_pos": True,
+    },
    prep_type="patches",
    spatial_downsample=4,
    temporal_downsample=1,
@@ -1824,12 +1834,12 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
use_query_residual=False,
position_encoding_only=True,
position_encoding_type="fourier",
-fourier_position_encoding_kwargs=dict(
-    num_bands=32,
-    max_resolution=(config.num_frames, config.image_size, config.image_size),
-    sine_only=False,
-    concat_pos=True,
-),
+fourier_position_encoding_kwargs={
+    "num_bands": 32,
+    "max_resolution": (config.num_frames, config.image_size, config.image_size),
+    "sine_only": False,
+    "concat_pos": True,
+},
)
decoder = PerceiverMultimodalDecoder(
@@ -1848,12 +1858,12 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
use_query_residual=False,
position_encoding_only=True,
position_encoding_type="fourier",
-fourier_position_encoding_kwargs=dict(
-    num_bands=192,
-    max_resolution=(n_audio_samples,),
-    sine_only=False,
-    concat_pos=True,
-),
+fourier_position_encoding_kwargs={
+    "num_bands": 192,
+    "max_resolution": (n_audio_samples,),
+    "sine_only": False,
+    "concat_pos": True,
+},
),
"image": image_decoder,
"label": PerceiverClassificationDecoder(
@@ -1863,10 +1873,10 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
use_query_residual=False,
position_encoding_only=True,
position_encoding_type="trainable",
-trainable_position_encoding_kwargs=dict(
-    num_channels=1024,
-    index_dims=1,
-),
+trainable_position_encoding_kwargs={
+    "num_channels": 1024,
+    "index_dims": 1,
+},
),
},
num_outputs=None,
@@ -2180,9 +2190,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
# to get the indices for the unflattened array
# unravel_index returns a tuple (x_idx, y_idx, ...)
# stack to get the [n, d] tensor of coordinates
-indices = list(
-    torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)
-)
+indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)]
pos = torch.stack(indices, dim=1)
batch_size = inputs.shape[0]
# Map these coordinates to [-1, 1]
@@ -2476,9 +2484,9 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
inputs = restructure(modality_sizes, inputs)
# Obtain modality-specific decoders' queries
-subsampled_points = subsampled_points or dict()
-decoder_queries = dict()
+subsampled_points = subsampled_points or {}
+decoder_queries = {}
for modality, decoder in self.modalities.items():
    # Get input_without_pos for this modality if it exists.
    input_without_pos = None
@@ -3363,7 +3371,7 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
super().__init__()
self.modalities = nn.ModuleDict(modalities)
self.min_padding_size = min_padding_size
-self.mask_probs = mask_probs if mask_probs is not None else dict()
+self.mask_probs = mask_probs if mask_probs is not None else {}
self.padding = nn.ParameterDict(
    {
        modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels))
...
@@ -297,7 +297,7 @@ class PhobertTokenizer(PreTrainedTokenizer):
words = re.findall(r"\S+\n?", text)
for token in words:
-    split_tokens.extend([t for t in self.bpe(token).split(" ")])
+    split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
def _convert_token_to_id(self, token):
...
@@ -294,7 +294,7 @@ class RealmTokenizer(PreTrainedTokenizer):
if encoded_token_type_ids is not None:
    output_data["token_type_ids"].append(encoded_token_type_ids)
-output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0)
+output_data = {key: item for key, item in output_data.items() if len(item) != 0}
return BatchEncoding(output_data, tensor_type=return_tensors)
...
@@ -259,7 +259,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
if encoded_token_type_ids is not None:
    output_data["token_type_ids"].append(encoded_token_type_ids)
-output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0)
+output_data = {key: item for key, item in output_data.items() if len(item) != 0}
return BatchEncoding(output_data, tensor_type=return_tensors)
...
@@ -87,7 +87,7 @@ def _get_least_common_mult_chunk_len(config):
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
    return config.local_attn_chunk_length
-elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
+elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
    return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
    raise NotImplementedError(
@@ -103,7 +103,7 @@ def _get_min_chunk_len(config):
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
    return config.local_attn_chunk_length
-elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
+elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
    return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
    raise NotImplementedError(
@@ -1277,7 +1277,7 @@ class ReformerAttention(nn.Module):
self.self_attention = LSHSelfAttention(config)
elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local":
    self.self_attention = LocalSelfAttention(config)
-elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == set(["lsh", "local"]):
+elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}:
    # get correct attn layers
    if self.attn_layers[self.layer_id] == "lsh":
        self.self_attention = LSHSelfAttention(config)
...