Unverified Commit 1967c43e authored by Chi-Liang, Liu's avatar Chi-Liang, Liu Committed by GitHub
Browse files

BartEnocder add set_input_embeddings (#13960)

* BartEnocder add set_input_embeddings

To unify the interface, add set_input_embeddings to BartEncoder.

* BartEnocder add get_input_embeddings
parent 3e04a41a
...@@ -694,6 +694,12 @@ class BartEncoder(BartPretrainedModel): ...@@ -694,6 +694,12 @@ class BartEncoder(BartPretrainedModel):
self.init_weights() self.init_weights()
self.gradient_checkpointing = False self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment