Commit a80778f4 authored by Francesco's avatar Francesco Committed by Lysandre Debut
Browse files

small refactoring (only esthetic, not functional)

parent 3df1d2d1
......@@ -191,13 +191,14 @@ class PreTrainedEncoderDecoder(nn.Module):
assert(len(os.listdir(save_directory)) == 0) # sanity check
# Create the "encoder" directory inside the output directory and save the encoder into it
if not os.path.exists(os.path.join(save_directory, "encoder")):
os.mkdir(os.path.join(save_directory, "encoder"))
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
# Create the "encoder" directory inside the output directory and save the decoder into it
if not os.path.exists(os.path.join(save_directory, "decoder")):
os.mkdir(os.path.join(save_directory, "decoder"))
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment