Unverified Commit ed15b471 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#3916)

parent dc5035b1
......@@ -319,7 +319,7 @@ def main():
**pipeline_cfg["optimizer"])
# train
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"])
torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc
...
......
......@@ -112,7 +112,7 @@ def main():
loss = torch.nn.{{ loss }}()
optimizer = torch.optim.Adam(params, **pipeline_cfg["optimizer"])
test_hits = train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"])
torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_hits
if __name__ == '__main__':
......
......@@ -112,7 +112,7 @@ def main():
optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"])
# train
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"])
torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc
if __name__ == '__main__':
......
......@@ -158,7 +158,7 @@ def main():
loss = torch.nn.{{ user_cfg.general_pipeline.loss }}()
optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg["optimizer"])
test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)
torch.save(model, pipeline_cfg["save_path"])
torch.save(model.state_dict(), pipeline_cfg["save_path"])
return test_acc
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment