Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
99f726ff
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "9ce79e8b4cd9c87a7f7b5fdacbd7a35ab76c2c22"
Unverified
Commit
99f726ff
authored
Sep 27, 2022
by
Chang Liu
Committed by
GitHub
Sep 27, 2022
Browse files
Fix MNIST examples (#4632)
parent
7f50a6da
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
examples/pytorch/model_zoo/geometric/mnist.py
examples/pytorch/model_zoo/geometric/mnist.py
+2
-2
No files found.
examples/pytorch/model_zoo/geometric/mnist.py
View file @
99f726ff
...
@@ -35,6 +35,8 @@ L, perm = coarsen(A, coarsening_levels)
...
@@ -35,6 +35,8 @@ L, perm = coarsen(A, coarsening_levels)
g_arr
=
[
dgl
.
from_scipy
(
csr
)
for
csr
in
L
]
g_arr
=
[
dgl
.
from_scipy
(
csr
)
for
csr
in
L
]
coordinate_arr
=
get_coordinates
(
g_arr
,
grid_side
,
coarsening_levels
,
perm
)
coordinate_arr
=
get_coordinates
(
g_arr
,
grid_side
,
coarsening_levels
,
perm
)
str_to_torch_dtype
=
{
'float16'
:
torch
.
half
,
'float32'
:
torch
.
float32
,
'float64'
:
torch
.
float64
}
coordinate_arr
=
[
coord
.
to
(
dtype
=
str_to_torch_dtype
[
str
(
A
.
dtype
)])
for
coord
in
coordinate_arr
]
for
g
,
coordinate_arr
in
zip
(
g_arr
,
coordinate_arr
):
for
g
,
coordinate_arr
in
zip
(
g_arr
,
coordinate_arr
):
g
.
ndata
[
'xy'
]
=
coordinate_arr
g
.
ndata
[
'xy'
]
=
coordinate_arr
g
.
apply_edges
(
z2polar
)
g
.
apply_edges
(
z2polar
)
...
@@ -99,8 +101,6 @@ class MoNet(nn.Module):
...
@@ -99,8 +101,6 @@ class MoNet(nn.Module):
u
=
g
.
edata
[
'u'
]
u
=
g
.
edata
[
'u'
]
feat
=
self
.
pool
(
layer
(
g
,
feat
,
u
).
transpose
(
-
1
,
-
2
).
unsqueeze
(
0
))
\
feat
=
self
.
pool
(
layer
(
g
,
feat
,
u
).
transpose
(
-
1
,
-
2
).
unsqueeze
(
0
))
\
.
squeeze
(
0
).
transpose
(
-
1
,
-
2
)
.
squeeze
(
0
).
transpose
(
-
1
,
-
2
)
print
(
feat
.
shape
)
print
(
g_arr
[
-
1
].
batch_size
)
return
self
.
cls
(
self
.
readout
(
g_arr
[
-
1
],
feat
))
return
self
.
cls
(
self
.
readout
(
g_arr
[
-
1
],
feat
))
class
ChebNet
(
nn
.
Module
):
class
ChebNet
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment