Unverified Commit a220f160 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[FlaxBart] make sure no grads are computed on bias (#16345)

* [FlaxBart] make sure no grads are computed on bias

* correct all other seq2seq models
parent 4975002d
......@@ -1292,7 +1292,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
- lm_logits += self.final_logits_bias.astype(self.dtype)
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
......@@ -1270,7 +1270,7 @@ class FlaxBlenderbotForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
- lm_logits += self.final_logits_bias.astype(self.dtype)
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
......@@ -1267,7 +1267,7 @@ class FlaxBlenderbotSmallForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
- lm_logits += self.final_logits_bias.astype(self.dtype)
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
......@@ -1329,7 +1329,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
- lm_logits += self.final_logits_bias.astype(self.dtype)
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
......@@ -1280,7 +1280,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
else:
lm_logits = self.lm_head(hidden_states)
- lm_logits += self.final_logits_bias.astype(self.dtype)
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment