"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b43cb09aaa6d81f4e1f4a2537764e37aa823b30b"
Commit 3df1d2d1 authored by Francesco's avatar Francesco Committed by Lysandre Debut
Browse files

- Create the output directory (whose name is passed by the user in the...

- Create the output directory (whose name is passed by the user in the "save_directory" parameter) where it will be saved encoder and decoder, if not exists.
- Empty the output directory, if it contains any files or subdirectories.
- Create the "encoder" directory inside "save_directory", if not exists.
- Create the "decoder" directory inside "save_directory", if not exists.
- Save the encoder and the decoder in the previous two directories, respectively.
parent a436574b
...@@ -166,6 +166,37 @@ class PreTrainedEncoderDecoder(nn.Module): ...@@ -166,6 +166,37 @@ class PreTrainedEncoderDecoder(nn.Module):
We save the encoder' and decoder's parameters in two separate directories. We save the encoder' and decoder's parameters in two separate directories.
""" """
# If the root output directory does not exist, create it
if not os.path.exists(save_directory):
os.mkdir(save_directory)
# Check whether the output directory is empty or not
sub_directories = [directory for directory in os.listdir(save_directory)
if os.path.isdir(os.path.join(save_directory, directory))]
if len(sub_directories) > 0:
if "encoder" in sub_directories and "decoder" in sub_directories:
print("WARNING: there is an older version of encoder-decoder saved in" +\
" the output directory. The default behaviour is to overwrite them.")
# Empty the output directory
for directory_to_remove in sub_directories:
# Remove all files into the subdirectory
files_to_remove = os.listdir(os.path.join(save_directory, directory_to_remove))
for file_to_remove in files_to_remove:
os.remove(os.path.join(save_directory, directory_to_remove, file_to_remove))
# Remove the subdirectory itself
os.rmdir(os.path.join(save_directory, directory_to_remove))
assert(len(os.listdir(save_directory)) == 0) # sanity check
if not os.path.exists(os.path.join(save_directory, "encoder")):
os.mkdir(os.path.join(save_directory, "encoder"))
if not os.path.exists(os.path.join(save_directory, "decoder")):
os.mkdir(os.path.join(save_directory, "decoder"))
self.encoder.save_pretrained(os.path.join(save_directory, "encoder")) self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder")) self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment