Unverified Commit 3c4506e9 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] Add bool data type to backend. (#1487)

* add bool to F.data_type_dict

* add utest

* skip bool test for mx
parent f1e4f378
......@@ -28,6 +28,7 @@ def data_type_dict():
int16
int32
int64
bool
This function will be called only *once* during the initialization fo the
backend module. The returned dictionary will become the attributes of the
......
......@@ -28,7 +28,8 @@ def data_type_dict():
'int8' : np.int8,
'int16' : np.int16,
'int32' : np.int32,
'int64' : np.int64}
'int64' : np.int64,
'bool' : np.bool}
def cpu():
return mx.cpu()
......
......@@ -24,7 +24,8 @@ def data_type_dict():
'int8' : th.int8,
'int16' : th.int16,
'int32' : th.int32,
'int64' : th.int64}
'int64' : th.int64,
'bool' : th.bool}
def cpu():
return th.device('cpu')
......
......@@ -50,7 +50,8 @@ def data_type_dict():
'int8': tf.int8,
'int16': tf.int16,
'int32': tf.int32,
'int64': tf.int64}
'int64': tf.int64,
'bool' : tf.bool}
def cpu():
return "/cpu:0"
......
......@@ -4,6 +4,9 @@ from dgl.utils import Index, toindex
import backend as F
import dgl
import unittest
import pickle
import pytest
import io
N = 10
D = 5
......@@ -15,10 +18,10 @@ def check_fail(fn):
except:
return True
def create_test_data(grad=False):
c1 = F.randn((N, D))
c2 = F.randn((N, D))
c3 = F.randn((N, D))
def create_test_data(grad=False, dtype=F.float32):
c1 = F.astype(F.randn((N, D)), dtype)
c2 = F.astype(F.randn((N, D)), dtype)
c3 = F.astype(F.randn((N, D)), dtype)
if grad:
c1 = F.attach_grad(c1)
c2 = F.attach_grad(c2)
......@@ -357,6 +360,23 @@ def test_inplace():
newa2addr = id(f['a2'])
assert a2addr == newa2addr
def _reconstruct_pickle(obj):
f = io.BytesIO()
pickle.dump(obj, f)
f.seek(0)
obj = pickle.load(f)
f.close()
return obj
@pytest.mark.parametrize('dtype',
[F.float32, F.int32] if dgl.backend.backend_name == "mxnet" else [F.float32, F.int32, F.bool])
def test_pickle(dtype):
f = create_test_data(dtype=dtype)
newf = _reconstruct_pickle(f)
assert F.array_equal(f['a1'], newf['a1'])
assert F.array_equal(f['a2'], newf['a2'])
assert F.array_equal(f['a3'], newf['a3'])
if __name__ == '__main__':
test_create()
test_column1()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment