Unverified Commit 7fcee113 authored by Sylvain Gugger, committed by GitHub

Tpu tie weights (#13030)

* Fix tied weights on TPU

* Manually tie weights in no trainer examples

* Fix for test

* One last missing

* Getting owned by my scripts

* Address review comments

* Fix test

* Fix tests

* Fix reformer tests
parent 1bf38611
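For context on the fix: `transformers` ties input and output embeddings by assigning the very same `nn.Parameter` object to both modules, and moving a model to an XLA device can rebuild the parameters so the two attributes end up pointing at separate copies. Below is a minimal sketch of the failure mode and the manual fix; it assumes `torch_xla` is installed, and the `TinyTiedLM` module is a made-up toy, not part of the library.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm  # requires a TPU/XLA environment


class TinyTiedLM(nn.Module):
    """Toy stand-in for a model with tied input/output embeddings."""

    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Share one Parameter object between the embedding and the decoder.
        self.decoder.weight = self.embed.weight


model = TinyTiedLM()
print(model.decoder.weight is model.embed.weight)  # True on CPU

model = model.to(xm.xla_device())
# On TPU the move can leave the two attributes pointing at different copies
# (this is the disconnect the commit works around).
print(model.decoder.weight is model.embed.weight)

model.tie_weights()  # retie manually, as the Trainer and the no_trainer scripts now do
print(model.decoder.weight is model.embed.weight)  # True again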
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm

 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -403,6 +403,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )

+    # On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
     # shorter in multiprocess)
...
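For a self-contained picture of where this lands in the no_trainer scripts, here is a rough sketch with a tiny GPT-2 configuration and random data standing in for the real model and dataloaders (none of these objects come from the actual example script):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, DistributedType
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny config so the sketch runs quickly; the real scripts build the model from CLI args.
config = GPT2Config(n_layer=1, n_head=2, n_embd=8)
model = GPT2LMHeadModel(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

dataset = TensorDataset(torch.randint(0, config.vocab_size, (8, 16)))
train_dataloader = DataLoader(dataset, batch_size=2)
eval_dataloader = DataLoader(dataset, batch_size=2)

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# prepare() moved the model to the XLA device on TPU, which disconnects the tied
# weights, so the scripts retie them right after preparation.
if accelerator.distributed_type == DistributedType.TPU:
    model.tie_weights()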
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm

 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -448,6 +448,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )

+    # On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
     # shorter in multiprocess)
...
@@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)

+        for module in self.modules():
+            if hasattr(module, "_tie_weights"):
+                module._tie_weights()
+
     @staticmethod
     def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
         uninitialized_encoder_weights: List[str] = []
...
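The new loop means any submodule can opt into re-tying by exposing a `_tie_weights` method, which is exactly what the LM heads below now do. A small sketch of the mechanism (the `retie_submodules` helper and `ToyLMHead` class are illustrative stand-ins, not library code):

import torch
import torch.nn as nn


def retie_submodules(model: nn.Module):
    # Stand-in for the loop added to PreTrainedModel.tie_weights above: walk every
    # submodule and let it re-establish its own internal ties.
    for module in model.modules():
        if hasattr(module, "_tie_weights"):
            module._tie_weights()


class ToyLMHead(nn.Module):
    # Mirrors the LM-head pattern in the model files below: a standalone bias
    # Parameter that must stay the same object as decoder.bias.
    def __init__(self, hidden_size=8, vocab_size=16):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-link the two attributes if something (e.g. a device move) separated them.
        self.bias = self.decoder.bias


head = ToyLMHead()
head.bias = nn.Parameter(torch.zeros(16))   # simulate the tie being broken
retie_submodules(head)
print(head.bias is head.decoder.bias)       # True again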
@@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
         self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
         self.activation = ACT2FN[config.hidden_act]
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, hidden_states):
@@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):

         return prediction_scores

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 class AlbertSOPHead(nn.Module):
     def __init__(self, config):
...
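The comment above also mentions the resized-bias case; a quick way to check that the head's standalone `bias` stays linked to `decoder.bias` is sketched below, using a deliberately tiny, made-up Albert configuration (the sizes are arbitrary and the expected outputs assume the resize path of this version of the library):

from transformers import AlbertConfig, AlbertForMaskedLM

# Tiny hypothetical config so the check runs instantly.
config = AlbertConfig(
    vocab_size=32,
    embedding_size=8,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
)
model = AlbertForMaskedLM(config)

# The MLM head bias and the decoder bias should be the same Parameter object.
print(model.predictions.bias is model.predictions.decoder.bias)  # True

# Resizing the embeddings calls tie_weights() internally, which now also calls the
# head's _tie_weights() so the standalone bias follows the resized decoder bias.
model.resize_token_embeddings(40)
print(model.predictions.bias is model.predictions.decoder.bias)  # True
print(model.predictions.bias.shape)  # torch.Size([40])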
@@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
 class BertGenerationOnlyLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, hidden_states):
         logits = self.decoder(hidden_states)
         return logits

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 @add_start_docstrings(
     """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
...
@@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, features, **kwargs):
@@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):

         return x

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 @add_start_docstrings(
     """
...
@@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, features, **kwargs):
@@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):

         return x

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 class LongformerPreTrainedModel(PreTrainedModel):
     """
...
@@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
         self.chunk_size_lm_head = config.chunk_size_lm_head
         self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, hidden_states):
@@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
         hidden_states = self.decoder(hidden_states)
         return hidden_states

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 class ReformerPreTrainedModel(PreTrainedModel):
     """
...
@@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

     def forward(self, features, **kwargs):
@@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):

         return x

+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+

 @add_start_docstrings(
     """
...
@@ -364,7 +364,7 @@ class Trainer:
         self.tokenizer = tokenizer

         if self.place_model_on_device:
-            model = model.to(args.device)
+            self._move_model_to_device(model, args.device)

         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
         if self.is_model_parallel:
@@ -505,6 +505,12 @@ class Trainer:
         """
         self.callback_handler.remove_callback(callback)

+    def _move_model_to_device(self, model, device):
+        model = model.to(device)
+        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+            model.tie_weights()
+
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return dataset
@@ -1017,7 +1023,7 @@ class Trainer:
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
         if args.fp16_full_eval and not args.do_train:
-            self.model = self.model.to(args.device)
+            self._move_model_to_device(self.model, args.device)

         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
@@ -1078,7 +1084,7 @@ class Trainer:
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
             if self.place_model_on_device:
-                self.model = self.model.to(args.device)
+                self._move_model_to_device(self.model, args.device)
             self.model_wrapped = self.model

         # Keeping track whether we can len() on the dataset or not
...
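As a standalone illustration of what the new helper does, the sketch below reproduces its logic outside the Trainer; the `move_model_to_device` function and the throwaway `TrainingArguments` (with a hypothetical output directory) are only for demonstration, not library code.

import torch
from transformers import GPT2Config, GPT2LMHeadModel, TrainingArguments
from transformers.training_args import ParallelMode


def move_model_to_device(model, device, args):
    """Standalone version of the logic in Trainer._move_model_to_device."""
    model = model.to(device)
    # Moving a model to an XLA device disconnects the tied weights, so retie them.
    if args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
        model.tie_weights()
    return model


args = TrainingArguments(output_dir="tmp_trainer_output")  # hypothetical output dir
model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=8))
model = move_model_to_device(model, args.device, args)

# On CPU/GPU the ties survive the move; on TPU the tie_weights() call above restores them.
print(model.lm_head.weight is model.transformer.wte.weight)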