Unverified Commit 1f728657 authored by dougian's avatar dougian Committed by GitHub
Browse files

[BART] Update encoder and decoder on set_input_embedding (#3501)


Co-authored-by: default avatarIoannis Douratsos <ioannisd@amazon.com>
parent cc598b31
...@@ -805,6 +805,8 @@ class BartModel(PretrainedBartModel): ...@@ -805,6 +805,8 @@ class BartModel(PretrainedBartModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.shared = value self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_output_embeddings(self): def get_output_embeddings(self):
return _make_linear_from_emb(self.shared) # make it on the fly return _make_linear_from_emb(self.shared) # make it on the fly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment