OpenDAS / dgl / Commits

Commit 7451bb2a (unverified)
merge eval results in all processes. (#1160)
Authored Jan 01, 2020 by Da Zheng; committed by GitHub on Jan 01, 2020.
Parent: dfb10db8
Showing 4 changed files with 38 additions and 9 deletions (+38, -9):
apps/kg/eval.py           +14  -1
apps/kg/train.py          +17  -1
apps/kg/train_mxnet.py     +1  -1
apps/kg/train_pytorch.py   +6  -6
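Taken together, the change follows a common multiprocessing pattern: every evaluation process puts its metrics dict on a shared queue, and the parent averages the dicts across processes before printing. Below is a minimal, self-contained sketch of that pattern; the worker function, metric names, and values are illustrative stand-ins, not the DGL code itself.

import multiprocessing as mp

# Illustrative stand-in for the per-process evaluation: each worker
# computes its own metrics dict and reports it through the queue.
def evaluate(rank, queue):
    metrics = {'MRR': 0.5 + 0.01 * rank, 'HITS@10': 0.7}
    queue.put(metrics)

def main():
    num_proc = 4
    queue = mp.Queue(num_proc)
    procs = []
    for i in range(num_proc):
        proc = mp.Process(target=evaluate, args=(i, queue))
        procs.append(proc)
        proc.start()

    # Merge: average every metric over the per-process results.
    total_metrics = {}
    for _ in range(num_proc):
        metrics = queue.get()
        for k, v in metrics.items():
            total_metrics[k] = total_metrics.get(k, 0) + v / num_proc

    for proc in procs:
        proc.join()

    for k, v in total_metrics.items():
        print('Test average {}: {}'.format(k, v))

if __name__ == '__main__':
    main()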
apps/kg/eval.py

@@ -139,13 +139,26 @@ def main(args):
     args.step = 0
     args.max_step = 0
     if args.num_proc > 1:
+        queue = mp.Queue(args.num_proc)
         procs = []
         for i in range(args.num_proc):
-            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]], 'Test', queue))
             procs.append(proc)
             proc.start()
         for proc in procs:
             proc.join()
+        total_metrics = {}
+        for i in range(args.num_proc):
+            metrics = queue.get()
+            for k, v in metrics.items():
+                if i == 0:
+                    total_metrics[k] = v / args.num_proc
+                else:
+                    total_metrics[k] += v / args.num_proc
+        for k, v in metrics.items():
+            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
     else:
         test(args, model, [test_sampler_head, test_sampler_tail])
apps/kg/train.py

@@ -263,16 +263,32 @@ def run(args, logger):
     # test
     if args.test:
         start = time.time()
         if args.num_proc > 1:
+            queue = mp.Queue(args.num_proc)
             procs = []
             for i in range(args.num_proc):
-                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]], 'Test', queue))
                 procs.append(proc)
                 proc.start()
+            total_metrics = {}
+            for i in range(args.num_proc):
+                metrics = queue.get()
+                for k, v in metrics.items():
+                    if i == 0:
+                        total_metrics[k] = v / args.num_proc
+                    else:
+                        total_metrics[k] += v / args.num_proc
+            for k, v in metrics.items():
+                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
             for proc in procs:
                 proc.join()
         else:
             test(args, model, [test_sampler_head, test_sampler_tail])
         print('test:', time.time() - start)

 if __name__ == '__main__':
     args = ArgParser().parse_args()
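The two call sites above differ in ordering: eval.py joins the worker processes before reading from the queue, while train.py drains the queue first and joins afterwards. The multiprocessing documentation notes that a process which has put data on a Queue may not exit until that data is consumed, so draining before joining is the safer ordering for large payloads; the metric dicts here are small, so either ordering works in practice. A minimal sketch of the drain-then-join ordering, with illustrative names only:

import multiprocessing as mp

def worker(rank, queue):
    # Illustrative payload; in the commit this is the metrics dict from test().
    queue.put({'rank': rank})

if __name__ == '__main__':
    num_proc = 4
    queue = mp.Queue(num_proc)
    procs = [mp.Process(target=worker, args=(i, queue)) for i in range(num_proc)]
    for p in procs:
        p.start()
    results = [queue.get() for _ in range(num_proc)]  # drain the queue first...
    for p in procs:
        p.join()                                      # ...then join the workers
    print(results)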
apps/kg/train_mxnet.py

@@ -61,7 +61,7 @@ def train(args, model, train_sampler, valid_samplers=None):
             # clear cache
             logs = []

-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     logs = []
     for sampler in test_samplers:
apps/kg/train_pytorch.py

@@ -80,10 +80,9 @@ def train(args, model, train_sampler, valid_samplers=None):
                 test(args, model, valid_samplers, mode='Valid')
                 print('test:', time.time() - start)

-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     if args.num_proc > 1:
         th.set_num_threads(1)
-    start = time.time()
     with th.no_grad():
         logs = []
         for sampler in test_samplers:

@@ -96,9 +95,10 @@ def test(args, model, test_samplers, mode='Test'):
         if len(logs) > 0:
             for metric in logs[0].keys():
                 metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
-        for k, v in metrics.items():
-            print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
-        print('test:', time.time() - start)
+        if queue is not None:
+            queue.put(metrics)
+        else:
+            for k, v in metrics.items():
+                print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
     test_samplers[0] = test_samplers[0].reset()
     test_samplers[1] = test_samplers[1].reset()
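The train_mxnet.py and train_pytorch.py changes give test() an optional queue parameter: with the default queue=None the function keeps printing its averaged metrics locally, while a multi-process caller passes a queue and receives the raw metrics dict for merging in the parent. A minimal sketch of that dual-mode reporting, using an illustrative report() helper rather than the DGL signatures:

import multiprocessing as mp

def report(metrics, mode='Test', queue=None):
    if queue is not None:
        # Multi-process mode: hand the raw metrics back to the parent for merging.
        queue.put(metrics)
    else:
        # Single-process mode: print the averages directly, as before the change.
        for k, v in metrics.items():
            print('{} average {}: {}'.format(mode, k, v))

if __name__ == '__main__':
    report({'MRR': 0.51})              # prints locally
    q = mp.Queue(1)
    report({'MRR': 0.51}, queue=q)     # defers reporting to the caller
    print(q.get())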