"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7533e3d7e6c6dd4c9c45d8330d6cffeca8bdb08f"
Unverified commit a8ce0957, authored by Pein, committed by GitHub

update the code of examples/pytorch/metapath2vec.py (#2463)


Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 4464b973
```diff
@@ -36,11 +36,11 @@ class Metapath2VecTrainer:

     def train(self):
-        optimizer = optim.SparseAdam(self.skip_gram_model.parameters(), lr=self.initial_lr)
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))
         for iteration in range(self.iterations):
             print("\n\n\nIteration: " + str(iteration + 1))
+            optimizer = optim.SparseAdam(self.skip_gram_model.parameters(), lr=self.initial_lr)
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))
             running_loss = 0.0
             for i, sample_batched in enumerate(tqdm(self.dataloader)):
@@ -59,7 +59,7 @@ class Metapath2VecTrainer:
             if i > 0 and i % 500 == 0:
                 print(" Loss: " + str(running_loss))
         self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)

 if __name__ == '__main__':
```
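The patch moves construction of the SparseAdam optimizer and its CosineAnnealingLR scheduler from before the iteration loop to inside it, so the cosine schedule restarts at initial_lr at the start of every iteration and anneals over one pass of the dataloader, rather than decaying only once across the entire run. Below is a minimal, self-contained sketch of the resulting schedule; the toy embedding table stands in for the skip-gram model, and `initial_lr`, `num_iterations`, and `batches_per_epoch` are illustrative values, not the example's defaults.

```python
import torch
import torch.optim as optim

# Standalone sketch (not the trainer itself): a sparse embedding stands in
# for the skip-gram model so that SparseAdam receives sparse gradients.
embedding = torch.nn.Embedding(10, 4, sparse=True)
initial_lr = 0.025        # illustrative, not the example's default
num_iterations = 3        # outer loop, as in self.iterations
batches_per_epoch = 5     # stands in for len(self.dataloader)

for iteration in range(num_iterations):
    # Recreated every iteration, as in the patched train(): the cosine
    # schedule restarts at initial_lr and anneals toward zero over one
    # (here simulated) pass of the dataloader.
    optimizer = optim.SparseAdam(embedding.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, batches_per_epoch)
    for i in range(batches_per_epoch):
        loss = embedding(torch.tensor([0])).sum()  # stand-in for the skip-gram loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"iteration {iteration + 1}, batch {i}, lr={scheduler.get_last_lr()[0]:.5f}")
```

One side effect of this restart pattern: because a fresh SparseAdam is built each iteration, its running moment estimates reset along with the learning rate.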