Commit b0eaff36 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Added a +1 to epoch when saving weights

parent 611961ad
......@@ -545,10 +545,11 @@ def train_discriminator(
if save_model:
# torch.save(discriminator.state_dict(),
# "{}_discriminator_{}.pt".format(
# args.dataset, epoch
# args.dataset, epoch + 1
# ))
torch.save(discriminator.get_classifier().state_dict(),
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
"{}_classifier_head_epoch_{}.pt".format(dataset,
epoch + 1))
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment