Unverified Commit da1af21d authored by andreeahedes, committed by GitHub

PegasusX add _no_split_modules (#25933)

* no_split_modules

* no_split_modules

* inputs_embeds+pos same device

* update _no_split_modules

* update _no_split_modules
parent 70a98024
@@ -769,6 +769,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
     config_class = PegasusXConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1299,6 +1300,8 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         # embed positions
         positions = self.embed_positions(inputs_embeds, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
+
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
...
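
For context, a minimal usage sketch (not part of the diff) of what this change enables: with `_no_split_modules` defined, Accelerate's `device_map="auto"` can shard PegasusX across devices without splitting an encoder or decoder layer, and the `positions.to(inputs_embeds.device)` fix keeps the position-embedding addition valid when layers land on different devices. The checkpoint name below is an illustrative example, not taken from this commit.

from transformers import AutoTokenizer, PegasusXForConditionalGeneration

# Example checkpoint for illustration; any PegasusX checkpoint should work.
model = PegasusXForConditionalGeneration.from_pretrained(
    "google/pegasus-x-base",
    device_map="auto",  # dispatch layers across available GPUs/CPU (requires accelerate)
)
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")

inputs = tokenizer("Summarize: the quick brown fox ...", return_tensors="pt")
# Inputs may sit on a different device than some sharded layers; the device
# fix in the decoder keeps inputs_embeds + positions on the same device.
summary_ids = model.generate(**inputs.to(model.device), max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))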