Unverified Commit 3c4506e9 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] Add bool data type to backend. (#1487)

* add bool to F.data_type_dict

* add utest

* skip bool test for mx
parent f1e4f378
...@@ -28,6 +28,7 @@ def data_type_dict(): ...@@ -28,6 +28,7 @@ def data_type_dict():
int16 int16
int32 int32
int64 int64
bool
This function will be called only *once* during the initialization fo the This function will be called only *once* during the initialization fo the
backend module. The returned dictionary will become the attributes of the backend module. The returned dictionary will become the attributes of the
......
...@@ -28,7 +28,8 @@ def data_type_dict(): ...@@ -28,7 +28,8 @@ def data_type_dict():
'int8' : np.int8, 'int8' : np.int8,
'int16' : np.int16, 'int16' : np.int16,
'int32' : np.int32, 'int32' : np.int32,
'int64' : np.int64} 'int64' : np.int64,
'bool' : np.bool}
def cpu(): def cpu():
return mx.cpu() return mx.cpu()
......
...@@ -24,7 +24,8 @@ def data_type_dict(): ...@@ -24,7 +24,8 @@ def data_type_dict():
'int8' : th.int8, 'int8' : th.int8,
'int16' : th.int16, 'int16' : th.int16,
'int32' : th.int32, 'int32' : th.int32,
'int64' : th.int64} 'int64' : th.int64,
'bool' : th.bool}
def cpu(): def cpu():
return th.device('cpu') return th.device('cpu')
......
...@@ -50,7 +50,8 @@ def data_type_dict(): ...@@ -50,7 +50,8 @@ def data_type_dict():
'int8': tf.int8, 'int8': tf.int8,
'int16': tf.int16, 'int16': tf.int16,
'int32': tf.int32, 'int32': tf.int32,
'int64': tf.int64} 'int64': tf.int64,
'bool' : tf.bool}
def cpu(): def cpu():
return "/cpu:0" return "/cpu:0"
......
...@@ -4,6 +4,9 @@ from dgl.utils import Index, toindex ...@@ -4,6 +4,9 @@ from dgl.utils import Index, toindex
import backend as F import backend as F
import dgl import dgl
import unittest import unittest
import pickle
import pytest
import io
N = 10 N = 10
D = 5 D = 5
...@@ -15,10 +18,10 @@ def check_fail(fn): ...@@ -15,10 +18,10 @@ def check_fail(fn):
except: except:
return True return True
def create_test_data(grad=False): def create_test_data(grad=False, dtype=F.float32):
c1 = F.randn((N, D)) c1 = F.astype(F.randn((N, D)), dtype)
c2 = F.randn((N, D)) c2 = F.astype(F.randn((N, D)), dtype)
c3 = F.randn((N, D)) c3 = F.astype(F.randn((N, D)), dtype)
if grad: if grad:
c1 = F.attach_grad(c1) c1 = F.attach_grad(c1)
c2 = F.attach_grad(c2) c2 = F.attach_grad(c2)
...@@ -357,6 +360,23 @@ def test_inplace(): ...@@ -357,6 +360,23 @@ def test_inplace():
newa2addr = id(f['a2']) newa2addr = id(f['a2'])
assert a2addr == newa2addr assert a2addr == newa2addr
def _reconstruct_pickle(obj):
f = io.BytesIO()
pickle.dump(obj, f)
f.seek(0)
obj = pickle.load(f)
f.close()
return obj
@pytest.mark.parametrize('dtype',
[F.float32, F.int32] if dgl.backend.backend_name == "mxnet" else [F.float32, F.int32, F.bool])
def test_pickle(dtype):
f = create_test_data(dtype=dtype)
newf = _reconstruct_pickle(f)
assert F.array_equal(f['a1'], newf['a1'])
assert F.array_equal(f['a2'], newf['a2'])
assert F.array_equal(f['a3'], newf['a3'])
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
test_column1() test_column1()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment