"git@developer.sourcefind.cn:change/sglang.git" did not exist on "4a0d19198bf9222edcb9879028990b481f8ffe56"
Unverified Commit 614acf2c authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[KG] save config when saving the model (#1336)

* save config.

* save more.
parent 7ee77f72
......@@ -5,6 +5,7 @@ import argparse
import os
import logging
import time
import json
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
......@@ -365,6 +366,25 @@ def run(args, logger):
os.mkdir(args.save_emb)
model.save_emb(args.save_emb, args.dataset)
# We need to save the model configurations as well.
conf_file = os.path.join(args.save_emb, 'config.json')
with open(conf_file, 'w') as outfile:
json.dump({'dataset': args.dataset,
'model': args.model_name,
'emb_size': args.hidden_dim,
'max_train_step': args.max_step,
'batch_size': args.batch_size,
'neg_sample_size': args.neg_sample_size,
'lr': args.lr,
'gamma': args.gamma,
'double_ent': args.double_ent,
'double_rel': args.double_rel,
'neg_adversarial_sampling': args.neg_adversarial_sampling,
'adversarial_temperature': args.adversarial_temperature,
'regularization_coef': args.regularization_coef,
'regularization_norm': args.regularization_norm},
outfile, indent=4)
# test
if args.test:
start = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment