Unverified Commit ed15b471 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#3916)

parent dc5035b1
...@@ -319,7 +319,7 @@ def main(): ...@@ -319,7 +319,7 @@ def main():
**pipeline_cfg["optimizer"]) **pipeline_cfg["optimizer"])
# train # train
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss) test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"]) torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc return test_acc
... ...
......
...@@ -112,7 +112,7 @@ def main(): ...@@ -112,7 +112,7 @@ def main():
loss = torch.nn.{{ loss }}() loss = torch.nn.{{ loss }}()
optimizer = torch.optim.Adam(params, **pipeline_cfg["optimizer"]) optimizer = torch.optim.Adam(params, **pipeline_cfg["optimizer"])
test_hits = train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss) test_hits = train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"]) torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_hits return test_hits
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -112,7 +112,7 @@ def main(): ...@@ -112,7 +112,7 @@ def main():
optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"]) optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"])
# train # train
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss) test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"]) torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc return test_acc
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -158,7 +158,7 @@ def main(): ...@@ -158,7 +158,7 @@ def main():
loss = torch.nn.{{ user_cfg.general_pipeline.loss }}() loss = torch.nn.{{ user_cfg.general_pipeline.loss }}()
optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"]) optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"])
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss) test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"]) torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc return test_acc
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment