Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
62b4bbb4
Unverified
Commit
62b4bbb4
authored
Nov 16, 2020
by
mszarma
Committed by
GitHub
Nov 16, 2020
Browse files
[Fix] Enable mini-batch rgcn for CPU (#2345)
parent
77968e30
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
2 deletions
+2
-2
examples/pytorch/rgcn/entity_classify_mp.py
examples/pytorch/rgcn/entity_classify_mp.py
+1
-1
examples/pytorch/rgcn/model.py
examples/pytorch/rgcn/model.py
+1
-1
No files found.
examples/pytorch/rgcn/entity_classify_mp.py
View file @
62b4bbb4
...
@@ -185,7 +185,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
...
@@ -185,7 +185,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
@
thread_wrapped_func
@
thread_wrapped_func
def
run
(
proc_id
,
n_gpus
,
args
,
devices
,
dataset
,
split
,
queue
=
None
):
def
run
(
proc_id
,
n_gpus
,
args
,
devices
,
dataset
,
split
,
queue
=
None
):
dev_id
=
devices
[
proc_id
]
dev_id
=
devices
[
proc_id
]
if
devices
[
proc_id
]
!=
'cpu'
else
-
1
g
,
node_feats
,
num_of_ntype
,
num_classes
,
num_rels
,
target_idx
,
\
g
,
node_feats
,
num_of_ntype
,
num_classes
,
num_rels
,
target_idx
,
\
train_idx
,
val_idx
,
test_idx
,
labels
=
dataset
train_idx
,
val_idx
,
test_idx
,
labels
=
dataset
if
split
is
not
None
:
if
split
is
not
None
:
...
...
examples/pytorch/rgcn/model.py
View file @
62b4bbb4
...
@@ -78,7 +78,7 @@ class RelGraphEmbedLayer(nn.Module):
...
@@ -78,7 +78,7 @@ class RelGraphEmbedLayer(nn.Module):
sparse_emb
=
False
,
sparse_emb
=
False
,
embed_name
=
'embed'
):
embed_name
=
'embed'
):
super
(
RelGraphEmbedLayer
,
self
).
__init__
()
super
(
RelGraphEmbedLayer
,
self
).
__init__
()
self
.
dev_id
=
dev_id
self
.
dev_id
=
th
.
device
(
dev_id
if
dev_id
>=
0
else
'cpu'
)
self
.
embed_size
=
embed_size
self
.
embed_size
=
embed_size
self
.
embed_name
=
embed_name
self
.
embed_name
=
embed_name
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment