Unverified commit 4cef546f, authored by Younes Belkada and committed by GitHub

Add `accelerate` support for BART-like models (#19927)



* forward contrib credits from suggestion

* add `accelerate` support for BART-like models
Co-authored-by: sgugger <sgugger@users.noreply.github.com>
parent ebfd7229
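
For context on what the diff below enables: once `_no_split_modules` is defined and the device-placement fixes are in `forward`, these models can be loaded through `accelerate`'s big-model inference with `device_map="auto"`. A minimal usage sketch; the checkpoint name, prompt, and generation settings are illustrative and not part of this commit:

```python
# Sketch only: requires `pip install accelerate`; checkpoint and prompt are examples.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/bart-large-cnn",
    device_map="auto",  # let accelerate place the layers across available GPUs/CPU
)

article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
inputs = tokenizer(article, return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```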
src/transformers/models/bart/modeling_bart.py

@@ -500,6 +500,7 @@ class BartPretrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
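
`_no_split_modules` lists the module classes that `accelerate` must keep on a single device when it computes a device map, since splitting an encoder or decoder block would separate weights joined by residual connections. It is passed along as `no_split_module_classes`; a rough sketch of the equivalent manual call, with made-up memory limits:

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("facebook/bart-large")
with init_empty_weights():  # instantiate on the meta device, no memory allocated
    model = AutoModelForSeq2SeqLM.from_config(config)

device_map = infer_auto_device_map(
    model,
    max_memory={0: "2GiB", "cpu": "10GiB"},  # illustrative limits
    no_split_module_classes=["BartEncoderLayer", "BartDecoderLayer"],
)
print(device_map)  # whole BartEncoderLayer / BartDecoderLayer blocks map to a single device each
```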
@@ -712,10 +713,10 @@ class BartEncoder(BartPretrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
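
The constructor change in the hunk above is the part that matters for `accelerate`: the encoder now always builds its own `nn.Embedding` (a real submodule the device map can place), and when a shared embedding is passed in, the weights are tied by assigning the `Parameter` rather than swapping in the whole module. A standalone sketch of that tying pattern, with made-up sizes:

```python
import torch.nn as nn

shared = nn.Embedding(50265, 1024, padding_idx=1)         # stand-in for the model's shared embedding
encoder_embed = nn.Embedding(50265, 1024, padding_idx=1)  # the encoder's own module

encoder_embed.weight = shared.weight  # tie: both modules now hold the same Parameter
assert encoder_embed.weight is shared.weight
```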
@@ -801,6 +802,7 @@ class BartEncoder(BartPretrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
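
The `embed_pos.to(inputs_embeds.device)` line added here (and the matching `positions.to(...)` casts in the decoder hunks below) guards against the case where, under a device map, the positional-embedding module sits on a different GPU than the token embeddings; PyTorch refuses to add tensors that live on different devices. A toy, single-device illustration where the `.to` is simply a no-op:

```python
import torch

inputs_embeds = torch.randn(1, 8, 16)           # stand-in for scaled token embeddings
embed_pos = torch.randn(8, 16)                  # stand-in for positional embeddings, possibly elsewhere
embed_pos = embed_pos.to(inputs_embeds.device)  # align devices before the addition
hidden_states = inputs_embeds + embed_pos       # broadcasts over the batch dimension
print(hidden_states.shape)                      # torch.Size([1, 8, 16])
```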
@@ -884,10 +886,10 @@ class BartDecoder(BartPretrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1043,6 +1045,7 @@ class BartDecoder(BartPretrainedModel):
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1373,7 +1376,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None:
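
`final_logits_bias` is a registered buffer rather than a parameter of `lm_head`, so once the model is dispatched across devices it is not guaranteed to share a device with the `lm_head` output; moving the small (1 x vocab_size) buffer at use time is the cheap fix. A small, hypothetical module showing the same pattern in isolation:

```python
import torch
import torch.nn as nn


class TinyLMHead(nn.Module):
    """Illustrative only: mimics lm_head plus a final_logits_bias buffer."""

    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab)))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.lm_head(hidden_states)
        # the buffer may live on another device when the model is sharded
        return logits + self.final_logits_bias.to(logits.device)


head = TinyLMHead()
print(head(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 32])
```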
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1595,6 +1595,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     config_class = BigBirdPegasusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1788,10 +1789,10 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2082,10 +2083,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2240,6 +2241,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
@@ -2573,7 +2575,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None:
src/transformers/models/plbart/modeling_plbart.py
@@ -506,6 +506,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
     config_class = PLBartConfig
     base_model_prefix = "model"
    supports_gradient_checkpointing = True
+    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -683,10 +684,10 @@ class PLBartEncoder(PLBartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -772,6 +773,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -856,10 +858,10 @@ class PLBartDecoder(PLBartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1015,6 +1017,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1334,7 +1337,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None: