Unverified Commit a220f160 authored by Patrick von Platen, committed by GitHub

[FlaxBart] make sure no grads are computed on the bias (#16345)

* [FlaxBart] make sure no grads are computed on the bias

* correct all other seq2seq models
parent 4975002d
@@ -1292,7 +1292,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias.astype(self.dtype)
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1270,7 +1270,7 @@ class FlaxBlenderbotForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias.astype(self.dtype)
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1267,7 +1267,7 @@ class FlaxBlenderbotSmallForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias.astype(self.dtype)
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1329,7 +1329,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias.astype(self.dtype)
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1280,7 +1280,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias.astype(self.dtype)
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
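For readers unfamiliar with the API: `jax.lax.stop_gradient` passes its argument through unchanged in the forward pass but treats it as a constant under differentiation. Below is a minimal sketch of the effect; it is illustrative only and not part of this commit (the toy loss functions and the `bias` variable are invented for the demo):

```python
import jax
import jax.numpy as jnp

def loss_plain(bias):
    # bias participates in both the forward and the backward pass
    logits = jnp.ones(3) + bias
    return jnp.sum(logits ** 2)

def loss_stopped(bias):
    # same forward value, but no gradient flows back into bias
    logits = jnp.ones(3) + jax.lax.stop_gradient(bias)
    return jnp.sum(logits ** 2)

bias = jnp.array([0.1, 0.2, 0.3])
print(jax.grad(loss_plain)(bias))    # [2.2 2.4 2.6] -- bias would receive updates
print(jax.grad(loss_stopped)(bias))  # [0. 0. 0.]    -- bias is frozen
```

This mirrors the change applied to each model above: `final_logits_bias` still shifts the logits in the forward pass, but the optimizer never receives a gradient for it.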