Unverified Commit 4e931a8e authored by Amelie Schreiber, committed by GitHub

Esm checkpointing (#26454)



* Fixed in-place operation error in EsmEmbeddings

* Fixed in-place operation error in EsmEmbeddings again

---------
Co-authored-by: Schreiber-Finance <amelie.schreiber.finance@gmail.com>
parent 5e11d72d
@@ -214,7 +214,7 @@ class EsmEmbeddings(nn.Module):
         # This is analogous to the way that dropout layers scale down outputs during evaluation when not
         # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
         if self.token_dropout:
-            embeddings.masked_fill_((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
             mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
             src_lengths = attention_mask.sum(-1)
             mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
@@ -224,7 +224,7 @@ class EsmEmbeddings(nn.Module):
         if self.position_embedding_type == "absolute":
             position_embeddings = self.position_embeddings(position_ids)
-            embeddings += position_embeddings
+            embeddings = embeddings + position_embeddings
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
@@ -399,7 +399,7 @@ class EsmSelfOutput(nn.Module):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states
@@ -474,7 +474,7 @@ class EsmOutput(nn.Module):
     def forward(self, hidden_states, input_tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states += input_tensor
+        hidden_states = hidden_states + input_tensor
         return hidden_states
@@ -633,7 +633,7 @@ class EsmEncoder(nn.Module):
             hidden_states = layer_outputs[0]
             if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
+                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
...
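For context, the sketch below is not part of the commit; it is a minimal reproduction of the general failure mode these changes avoid, assuming only a local PyTorch install (tensor names are illustrative, and the exact ESM gradient-checkpointing traceback differs). An in-place op such as += or masked_fill_ modifies a tensor that autograd saved for the backward pass, which raises a RuntimeError when gradients are computed, e.g. when activations are recomputed under gradient checkpointing. The out-of-place forms introduced above create new tensors instead.

import torch

# In-place update of a tensor that autograd saved for backward -> RuntimeError at backward time.
x = torch.ones(3, requires_grad=True)
y = torch.exp(x)          # exp() saves its output for the backward pass
y += 1                    # in-place op bumps y's version counter, invalidating the saved tensor
try:
    y.sum().backward()
except RuntimeError as err:
    print(f"in-place op broke autograd: {err}")

# Out-of-place form (the pattern used in this commit) leaves the saved tensor intact.
x.grad = None
y = torch.exp(x)
y = y + 1
y.sum().backward()
print(x.grad)             # tensor([2.7183, 2.7183, 2.7183])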