Unverified Commit 4cef546f authored by Younes Belkada, committed by GitHub

Add `accelerate` support for BART-like models (#19927)



* forward contrib credits from suggestion

* add `accelerate` support for BART-like models
Co-authored-by: sgugger <sgugger@users.noreply.github.com>
parent ebfd7229
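Note on intent: these changes let BART-family checkpoints be loaded through `accelerate`'s big-model-inference path (`device_map="auto"`). A minimal usage sketch is below; it assumes `accelerate` is installed, and the checkpoint name and generation settings are only illustrative, not part of this commit.

```python
# Sketch: loading a BART-like model with automatic device placement.
# Requires `pip install accelerate`; "facebook/bart-large-cnn" is just an example checkpoint.
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large-cnn",
    device_map="auto",          # accelerate dispatches layers across the available devices
    torch_dtype=torch.float16,
)

inputs = tokenizer("Summarize: accelerate support for BART-like models.", return_tensors="pt")
inputs = inputs.to(model.device)  # device of the first parameters when the model is dispatched

summary_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```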
@@ -500,6 +500,7 @@ class BartPretrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -712,11 +713,11 @@ class BartEncoder(BartPretrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
@@ -801,6 +802,7 @@ class BartEncoder(BartPretrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -884,11 +886,11 @@ class BartDecoder(BartPretrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
@@ -1043,6 +1045,7 @@ class BartDecoder(BartPretrainedModel):
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...@@ -1373,7 +1376,9 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1373,7 +1376,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
lm_logits = self.lm_head(outputs[0])
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
......
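Background on the new `_no_split_modules` attribute: when a `device_map` is requested, this class-level list tells `accelerate` which module classes must stay on a single device, so that a whole encoder or decoder layer (with its residual connections) is never split across GPUs. The sketch below shows the equivalent wiring using `accelerate`'s public utilities; inside `transformers` this happens internally during `from_pretrained`, so the explicit calls here are only for illustration.

```python
# Sketch: how a `_no_split_modules` list is consumed when building a device map.
# Assumes at least one GPU; "facebook/bart-base" is just an example checkpoint.
from accelerate import dispatch_model, infer_auto_device_map
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Keep each BartEncoderLayer / BartDecoderLayer on a single device.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,
)
model = dispatch_model(model, device_map=device_map)
```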
@@ -1595,6 +1595,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     config_class = BigBirdPegasusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1788,11 +1789,11 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
@@ -2082,11 +2083,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
@@ -2240,6 +2241,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
@@ -2573,7 +2575,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

         masked_lm_loss = None
         if labels is not None:
......
@@ -506,6 +506,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
     config_class = PLBartConfig
     base_model_prefix = "model"
    supports_gradient_checkpointing = True
+    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -683,11 +684,11 @@ class PLBartEncoder(PLBartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
@@ -772,6 +773,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -856,11 +858,11 @@ class PLBartDecoder(PLBartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight

         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
@@ -1015,6 +1017,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...@@ -1334,7 +1337,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): ...@@ -1334,7 +1337,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias lm_logits = self.lm_head(outputs[0])
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
......
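Background on the added `.to(...)` casts: once layers are spread over several devices, tensors produced on one GPU (position embeddings, the `final_logits_bias` buffer) must be moved to the device of the activation they are added to, otherwise PyTorch raises a device-mismatch error. The snippet below is a generic illustration of that pattern, not code from this diff; the tensor names and shapes are made up.

```python
# Illustration only: why cross-device additions need an explicit .to(...).
import torch

if torch.cuda.device_count() >= 2:
    hidden_states = torch.randn(2, 8, 16, device="cuda:1")  # activations living on GPU 1
    positional_bias = torch.randn(8, 16, device="cuda:0")    # buffer that stayed on GPU 0

    # Without the cast this raises "Expected all tensors to be on the same device".
    hidden_states = hidden_states + positional_bias.to(hidden_states.device)
```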