Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
dba36c87
"docs/vscode:/vscode.git/clone" did not exist on "6f9ae8d61eca4a2841bb06b47a993c523de6f43c"
Unverified
Commit
dba36c87
authored
Dec 16, 2018
by
Da Zheng
Committed by
GitHub
Dec 16, 2018
Browse files
Fix some minor problems in SSE (#309)
* fix SSE. * fix. * fix.
parent
0d0f4436
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
14 deletions
+20
-14
examples/mxnet/sse/README.md
examples/mxnet/sse/README.md
+4
-3
examples/mxnet/sse/sse_batch.py
examples/mxnet/sse/sse_batch.py
+16
-11
No files found.
examples/mxnet/sse/README.md
View file @
dba36c87
...
...
@@ -25,10 +25,11 @@ Test convergence
```
bash
DGLBACKEND
=
mxnet python3 sse_batch.py
--dataset
"pubmed"
\
--n-epochs
100
\
--n-epochs
100
0
\
--lr
0.001
\
--batch-size
1024
\
--batch-size
30
\
--dgl
\
--use-spmv
\
--neigh-expand
4
--neigh-expand
8
\
--n-hidden
16
```
examples/mxnet/sse/sse_batch.py
View file @
dba36c87
...
...
@@ -200,9 +200,14 @@ def main(args, data):
labels
=
data
.
labels
else
:
labels
=
mx
.
nd
.
array
(
data
.
labels
)
if
data
.
train_mask
is
not
None
:
train_vs
=
mx
.
nd
.
array
(
np
.
nonzero
(
data
.
train_mask
)[
0
],
dtype
=
'int64'
)
eval_vs
=
mx
.
nd
.
array
(
np
.
nonzero
(
data
.
train_mask
==
0
)[
0
],
dtype
=
'int64'
)
else
:
train_size
=
len
(
labels
)
*
args
.
train_percent
train_vs
=
mx
.
nd
.
arange
(
0
,
train_size
,
dtype
=
'int64'
)
eval_vs
=
mx
.
nd
.
arange
(
train_size
,
len
(
labels
),
dtype
=
'int64'
)
print
(
"train size: "
+
str
(
len
(
train_vs
)))
print
(
"eval size: "
+
str
(
len
(
eval_vs
)))
eval_labels
=
mx
.
nd
.
take
(
labels
,
eval_vs
)
...
...
@@ -305,9 +310,6 @@ def main(args, data):
+
" subgraphs takes "
+
str
(
end1
-
start1
))
start1
=
end1
if
i
>
num_batches
/
3
:
break
if
args
.
cache_subgraph
:
sampler
.
restart
()
else
:
...
...
@@ -317,10 +319,12 @@ def main(args, data):
seed_nodes
=
train_vs
,
shuffle
=
True
,
return_seed_id
=
True
)
#
prediction.
#
test set accuracy
logits
=
model_infer
(
g
,
eval_vs
)
eval_loss
=
mx
.
nd
.
softmax_cross_entropy
(
logits
,
eval_labels
)
eval_loss
=
eval_loss
.
asnumpy
()[
0
]
y_bar
=
mx
.
nd
.
argmax
(
logits
,
axis
=
1
)
y
=
eval_labels
accuracy
=
mx
.
nd
.
sum
(
y_bar
==
y
)
/
len
(
y
)
accuracy
=
accuracy
.
asnumpy
()[
0
]
# update the inference model.
infer_params
=
model_infer
.
collect_params
()
...
...
@@ -334,8 +338,8 @@ def main(args, data):
rets
.
append
(
all_hidden
)
dur
.
append
(
time
.
time
()
-
t0
)
print
(
"Epoch {:05d} | Train Loss {:.4f} |
Eval Loss
{:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}"
.
format
(
epoch
,
train_loss
,
eval_loss
,
np
.
mean
(
dur
),
n_edges
/
np
.
mean
(
dur
)
/
1000
))
print
(
"Epoch {:05d} | Train Loss {:.4f} |
Test Accuracy
{:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}"
.
format
(
epoch
,
train_loss
,
accuracy
,
np
.
mean
(
dur
),
n_edges
/
np
.
mean
(
dur
)
/
1000
))
return
rets
...
...
@@ -361,6 +365,7 @@ class GraphData:
self
.
graph
=
MXNetGraph
(
csr
)
self
.
features
=
mx
.
nd
.
random
.
normal
(
shape
=
(
csr
.
shape
[
0
],
num_feats
))
self
.
labels
=
mx
.
nd
.
floor
(
mx
.
nd
.
random
.
uniform
(
low
=
0
,
high
=
10
,
shape
=
(
csr
.
shape
[
0
])))
self
.
train_mask
=
None
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'GCN'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment