Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
614acf2c
Unverified
Commit
614acf2c
authored
Mar 08, 2020
by
Da Zheng
Committed by
GitHub
Mar 08, 2020
Browse files
[KG] save config when saving the model (#1336)
* save config. * save more.
parent
7ee77f72
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
1 deletion
+21
-1
apps/kg/models/general_models.py
apps/kg/models/general_models.py
+1
-1
apps/kg/train.py
apps/kg/train.py
+20
-0
No files found.
apps/kg/models/general_models.py
View file @
614acf2c
...
@@ -494,4 +494,4 @@ class KEModel(object):
...
@@ -494,4 +494,4 @@ class KEModel(object):
client
.
push
(
name
=
'relation_emb'
,
id_tensor
=
relation_id
,
data_tensor
=
grad
)
client
.
push
(
name
=
'relation_emb'
,
id_tensor
=
relation_id
,
data_tensor
=
grad
)
self
.
entity_emb
.
trace
=
[]
self
.
entity_emb
.
trace
=
[]
self
.
relation_emb
.
trace
=
[]
self
.
relation_emb
.
trace
=
[]
\ No newline at end of file
apps/kg/train.py
View file @
614acf2c
...
@@ -5,6 +5,7 @@ import argparse
...
@@ -5,6 +5,7 @@ import argparse
import
os
import
os
import
logging
import
logging
import
time
import
time
import
json
backend
=
os
.
environ
.
get
(
'DGLBACKEND'
,
'pytorch'
)
backend
=
os
.
environ
.
get
(
'DGLBACKEND'
,
'pytorch'
)
if
backend
.
lower
()
==
'mxnet'
:
if
backend
.
lower
()
==
'mxnet'
:
...
@@ -365,6 +366,25 @@ def run(args, logger):
...
@@ -365,6 +366,25 @@ def run(args, logger):
os
.
mkdir
(
args
.
save_emb
)
os
.
mkdir
(
args
.
save_emb
)
model
.
save_emb
(
args
.
save_emb
,
args
.
dataset
)
model
.
save_emb
(
args
.
save_emb
,
args
.
dataset
)
# We need to save the model configurations as well.
conf_file
=
os
.
path
.
join
(
args
.
save_emb
,
'config.json'
)
with
open
(
conf_file
,
'w'
)
as
outfile
:
json
.
dump
({
'dataset'
:
args
.
dataset
,
'model'
:
args
.
model_name
,
'emb_size'
:
args
.
hidden_dim
,
'max_train_step'
:
args
.
max_step
,
'batch_size'
:
args
.
batch_size
,
'neg_sample_size'
:
args
.
neg_sample_size
,
'lr'
:
args
.
lr
,
'gamma'
:
args
.
gamma
,
'double_ent'
:
args
.
double_ent
,
'double_rel'
:
args
.
double_rel
,
'neg_adversarial_sampling'
:
args
.
neg_adversarial_sampling
,
'adversarial_temperature'
:
args
.
adversarial_temperature
,
'regularization_coef'
:
args
.
regularization_coef
,
'regularization_norm'
:
args
.
regularization_norm
},
outfile
,
indent
=
4
)
# test
# test
if
args
.
test
:
if
args
.
test
:
start
=
time
.
time
()
start
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment