Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3c4506e9
Unverified
Commit
3c4506e9
authored
Apr 29, 2020
by
Minjie Wang
Committed by
GitHub
Apr 29, 2020
Browse files
[Bugfix] Add bool data type to backend. (#1487)
* add bool to F.data_type_dict * add utest * skip bool test for mx
parent
f1e4f378
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
7 deletions
+31
-7
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+1
-0
python/dgl/backend/mxnet/tensor.py
python/dgl/backend/mxnet/tensor.py
+2
-1
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+2
-1
python/dgl/backend/tensorflow/tensor.py
python/dgl/backend/tensorflow/tensor.py
+2
-1
tests/compute/test_frame.py
tests/compute/test_frame.py
+24
-4
No files found.
python/dgl/backend/backend.py
View file @
3c4506e9
...
@@ -28,6 +28,7 @@ def data_type_dict():
...
@@ -28,6 +28,7 @@ def data_type_dict():
int16
int16
int32
int32
int64
int64
bool
This function will be called only *once* during the initialization fo the
This function will be called only *once* during the initialization fo the
backend module. The returned dictionary will become the attributes of the
backend module. The returned dictionary will become the attributes of the
...
...
python/dgl/backend/mxnet/tensor.py
View file @
3c4506e9
...
@@ -28,7 +28,8 @@ def data_type_dict():
...
@@ -28,7 +28,8 @@ def data_type_dict():
'int8'
:
np
.
int8
,
'int8'
:
np
.
int8
,
'int16'
:
np
.
int16
,
'int16'
:
np
.
int16
,
'int32'
:
np
.
int32
,
'int32'
:
np
.
int32
,
'int64'
:
np
.
int64
}
'int64'
:
np
.
int64
,
'bool'
:
np
.
bool
}
def
cpu
():
def
cpu
():
return
mx
.
cpu
()
return
mx
.
cpu
()
...
...
python/dgl/backend/pytorch/tensor.py
View file @
3c4506e9
...
@@ -24,7 +24,8 @@ def data_type_dict():
...
@@ -24,7 +24,8 @@ def data_type_dict():
'int8'
:
th
.
int8
,
'int8'
:
th
.
int8
,
'int16'
:
th
.
int16
,
'int16'
:
th
.
int16
,
'int32'
:
th
.
int32
,
'int32'
:
th
.
int32
,
'int64'
:
th
.
int64
}
'int64'
:
th
.
int64
,
'bool'
:
th
.
bool
}
def
cpu
():
def
cpu
():
return
th
.
device
(
'cpu'
)
return
th
.
device
(
'cpu'
)
...
...
python/dgl/backend/tensorflow/tensor.py
View file @
3c4506e9
...
@@ -50,7 +50,8 @@ def data_type_dict():
...
@@ -50,7 +50,8 @@ def data_type_dict():
'int8'
:
tf
.
int8
,
'int8'
:
tf
.
int8
,
'int16'
:
tf
.
int16
,
'int16'
:
tf
.
int16
,
'int32'
:
tf
.
int32
,
'int32'
:
tf
.
int32
,
'int64'
:
tf
.
int64
}
'int64'
:
tf
.
int64
,
'bool'
:
tf
.
bool
}
def
cpu
():
def
cpu
():
return
"/cpu:0"
return
"/cpu:0"
...
...
tests/compute/test_frame.py
View file @
3c4506e9
...
@@ -4,6 +4,9 @@ from dgl.utils import Index, toindex
...
@@ -4,6 +4,9 @@ from dgl.utils import Index, toindex
import
backend
as
F
import
backend
as
F
import
dgl
import
dgl
import
unittest
import
unittest
import
pickle
import
pytest
import
io
N
=
10
N
=
10
D
=
5
D
=
5
...
@@ -15,10 +18,10 @@ def check_fail(fn):
...
@@ -15,10 +18,10 @@ def check_fail(fn):
except
:
except
:
return
True
return
True
def
create_test_data
(
grad
=
False
):
def
create_test_data
(
grad
=
False
,
dtype
=
F
.
float32
):
c1
=
F
.
randn
((
N
,
D
))
c1
=
F
.
astype
(
F
.
randn
((
N
,
D
))
,
dtype
)
c2
=
F
.
randn
((
N
,
D
))
c2
=
F
.
astype
(
F
.
randn
((
N
,
D
))
,
dtype
)
c3
=
F
.
randn
((
N
,
D
))
c3
=
F
.
astype
(
F
.
randn
((
N
,
D
))
,
dtype
)
if
grad
:
if
grad
:
c1
=
F
.
attach_grad
(
c1
)
c1
=
F
.
attach_grad
(
c1
)
c2
=
F
.
attach_grad
(
c2
)
c2
=
F
.
attach_grad
(
c2
)
...
@@ -357,6 +360,23 @@ def test_inplace():
...
@@ -357,6 +360,23 @@ def test_inplace():
newa2addr
=
id
(
f
[
'a2'
])
newa2addr
=
id
(
f
[
'a2'
])
assert
a2addr
==
newa2addr
assert
a2addr
==
newa2addr
def
_reconstruct_pickle
(
obj
):
f
=
io
.
BytesIO
()
pickle
.
dump
(
obj
,
f
)
f
.
seek
(
0
)
obj
=
pickle
.
load
(
f
)
f
.
close
()
return
obj
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
F
.
float32
,
F
.
int32
]
if
dgl
.
backend
.
backend_name
==
"mxnet"
else
[
F
.
float32
,
F
.
int32
,
F
.
bool
])
def
test_pickle
(
dtype
):
f
=
create_test_data
(
dtype
=
dtype
)
newf
=
_reconstruct_pickle
(
f
)
assert
F
.
array_equal
(
f
[
'a1'
],
newf
[
'a1'
])
assert
F
.
array_equal
(
f
[
'a2'
],
newf
[
'a2'
])
assert
F
.
array_equal
(
f
[
'a3'
],
newf
[
'a3'
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_create
()
test_create
()
test_column1
()
test_column1
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment