Unverified Commit 5558ce29 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Fix] Convert float64 to float32 when creating tensor (#3751)

* [Fix] Convert float64 to float32 when creating tensor

* refine docstring
parent e424d296
...@@ -77,6 +77,15 @@ def _validate_data_length(data_dict): ...@@ -77,6 +77,15 @@ def _validate_data_length(data_dict):
"All data are required to have same length while some of them does not. Length of data={}".format(str(len_dict))) "All data are required to have same length while some of them does not. Length of data={}".format(str(len_dict)))
def _tensor(data, dtype=None):
"""Float32 is the default dtype for float tensor in DGL
so let's cast float64 into float32 to avoid dtype mismatch.
"""
ret = F.tensor(data, dtype)
if F.dtype(ret) == F.float64:
ret = F.tensor(ret, dtype=F.float32)
return ret
class BaseData: class BaseData:
""" Class of base data which is inherited by Node/Edge/GraphData. Internal use only. """ """ Class of base data which is inherited by Node/Edge/GraphData. Internal use only. """
@staticmethod @staticmethod
...@@ -137,7 +146,7 @@ class NodeData(BaseData): ...@@ -137,7 +146,7 @@ class NodeData(BaseData):
node_dict[graph_id] = {} node_dict[graph_id] = {}
node_dict[graph_id][n_data.type] = {'mapping': {index: i for i, node_dict[graph_id][n_data.type] = {'mapping': {index: i for i,
index in enumerate(ids[u_indices])}, index in enumerate(ids[u_indices])},
'data': {k: F.tensor(v[idx][u_indices]) 'data': {k: _tensor(v[idx][u_indices])
for k, v in n_data.data.items()}} for k, v in n_data.data.items()}}
return node_dict return node_dict
...@@ -187,8 +196,8 @@ class EdgeData(BaseData): ...@@ -187,8 +196,8 @@ class EdgeData(BaseData):
dst_ids = [dst_mapping[index] for index in e_data.dst[idx]] dst_ids = [dst_mapping[index] for index in e_data.dst[idx]]
if graph_id not in edge_dict: if graph_id not in edge_dict:
edge_dict[graph_id] = {} edge_dict[graph_id] = {}
edge_dict[graph_id][e_data.type] = {'edges': (F.tensor(src_ids), F.tensor(dst_ids)), edge_dict[graph_id][e_data.type] = {'edges': (_tensor(src_ids), _tensor(dst_ids)),
'data': {k: F.tensor(v[idx]) 'data': {k: _tensor(v[idx])
for k, v in e_data.data.items()}} for k, v in e_data.data.items()}}
return edge_dict return edge_dict
...@@ -226,7 +235,7 @@ class GraphData(BaseData): ...@@ -226,7 +235,7 @@ class GraphData(BaseData):
{('_V', '_E', '_V'): ([], [])}) {('_V', '_E', '_V'): ([], [])})
for graph_id in graph_ids: for graph_id in graph_ids:
graphs.append(graphs_dict[graph_id]) graphs.append(graphs_dict[graph_id])
data = {k: F.tensor(v) for k, v in graph_data.data.items()} data = {k: _tensor(v) for k, v in graph_data.data.items()}
return graphs, data return graphs, data
......
...@@ -257,7 +257,9 @@ def _test_construct_graphs_homo(): ...@@ -257,7 +257,9 @@ def _test_construct_graphs_homo():
def assert_data(lhs, rhs): def assert_data(lhs, rhs):
for key, value in lhs.items(): for key, value in lhs.items():
assert key in rhs assert key in rhs
assert F.array_equal(F.tensor(value), rhs[key]) assert F.dtype(rhs[key]) != F.float64
assert F.array_equal(
F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key])
assert_data(ndata, g.ndata) assert_data(ndata, g.ndata)
assert_data(edata, g.edata) assert_data(edata, g.edata)
...@@ -314,7 +316,9 @@ def _test_construct_graphs_hetero(): ...@@ -314,7 +316,9 @@ def _test_construct_graphs_hetero():
def assert_data(lhs, rhs): def assert_data(lhs, rhs):
for key, value in lhs.items(): for key, value in lhs.items():
assert key in rhs assert key in rhs
assert F.array_equal(F.tensor(value), rhs[key]) assert F.dtype(rhs[key]) != F.float64
assert F.array_equal(
F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key])
for ntype in g.ntypes: for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes assert g.num_nodes(ntype) == num_nodes
assert_data(ndata_dict[ntype], g.nodes[ntype].data) assert_data(ndata_dict[ntype], g.nodes[ntype].data)
...@@ -364,7 +368,8 @@ def _test_construct_graphs_multiple(): ...@@ -364,7 +368,8 @@ def _test_construct_graphs_multiple():
assert len(graphs) == num_graphs assert len(graphs) == num_graphs
assert len(data_dict) == len(gdata) assert len(data_dict) == len(gdata)
for k, v in data_dict.items(): for k, v in data_dict.items():
assert F.array_equal(F.tensor(gdata[k]), v) assert F.dtype(v) != F.float64
assert F.array_equal(F.tensor(gdata[k], dtype=F.dtype(v)), v)
for i, g in enumerate(graphs): for i, g in enumerate(graphs):
assert g.is_homogeneous assert g.is_homogeneous
assert g.num_nodes() == num_nodes assert g.num_nodes() == num_nodes
...@@ -377,7 +382,9 @@ def _test_construct_graphs_multiple(): ...@@ -377,7 +382,9 @@ def _test_construct_graphs_multiple():
if node: if node:
indices = u_indices[i*size:(i+1)*size] indices = u_indices[i*size:(i+1)*size]
value = value[indices] value = value[indices]
assert F.array_equal(F.tensor(value), rhs[key]) assert F.dtype(rhs[key]) != F.float64
assert F.array_equal(
F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key])
assert_data(ndata, g.ndata, num_nodes, node=True) assert_data(ndata, g.ndata, num_nodes, node=True)
assert_data(edata, g.edata, num_edges) assert_data(edata, g.edata, num_edges)
...@@ -798,13 +805,13 @@ def _test_CSVDataset_single(): ...@@ -798,13 +805,13 @@ def _test_CSVDataset_single():
assert csv_dataset.has_cache() assert csv_dataset.has_cache()
for ntype in g.ntypes: for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes assert g.num_nodes(ntype) == num_nodes
assert F.array_equal(F.tensor(feat_ndata), assert F.array_equal(F.tensor(feat_ndata, dtype=F.float32),
g.nodes[ntype].data['feat']) g.nodes[ntype].data['feat'])
assert np.array_equal(label_ndata, assert np.array_equal(label_ndata,
F.asnumpy(g.nodes[ntype].data['label'])) F.asnumpy(g.nodes[ntype].data['label']))
for etype in g.etypes: for etype in g.etypes:
assert g.num_edges(etype) == num_edges assert g.num_edges(etype) == num_edges
assert F.array_equal(F.tensor(feat_edata), assert F.array_equal(F.tensor(feat_edata, dtype=F.float32),
g.edges[etype].data['feat']) g.edges[etype].data['feat'])
assert np.array_equal(label_edata, assert np.array_equal(label_edata,
F.asnumpy(g.edges[etype].data['label'])) F.asnumpy(g.edges[etype].data['label']))
...@@ -880,21 +887,21 @@ def _test_CSVDataset_multiple(): ...@@ -880,21 +887,21 @@ def _test_CSVDataset_multiple():
assert len(csv_dataset.data) == 2 assert len(csv_dataset.data) == 2
assert 'feat' in csv_dataset.data assert 'feat' in csv_dataset.data
assert 'label' in csv_dataset.data assert 'label' in csv_dataset.data
assert F.array_equal(F.tensor(feat_gdata), assert F.array_equal(F.tensor(feat_gdata, dtype=F.float32),
csv_dataset.data['feat']) csv_dataset.data['feat'])
for i, (g, g_data) in enumerate(csv_dataset): for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous assert not g.is_homogeneous
assert F.asnumpy(g_data['label']) == label_gdata[i] assert F.asnumpy(g_data['label']) == label_gdata[i]
assert F.array_equal(g_data['feat'], F.tensor(feat_gdata[i])) assert F.array_equal(g_data['feat'], F.tensor(feat_gdata[i], dtype=F.float32))
for ntype in g.ntypes: for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes assert g.num_nodes(ntype) == num_nodes
assert F.array_equal(F.tensor(feat_ndata[i*num_nodes:(i+1)*num_nodes]), assert F.array_equal(F.tensor(feat_ndata[i*num_nodes:(i+1)*num_nodes], dtype=F.float32),
g.nodes[ntype].data['feat']) g.nodes[ntype].data['feat'])
assert np.array_equal(label_ndata[i*num_nodes:(i+1)*num_nodes], assert np.array_equal(label_ndata[i*num_nodes:(i+1)*num_nodes],
F.asnumpy(g.nodes[ntype].data['label'])) F.asnumpy(g.nodes[ntype].data['label']))
for etype in g.etypes: for etype in g.etypes:
assert g.num_edges(etype) == num_edges assert g.num_edges(etype) == num_edges
assert F.array_equal(F.tensor(feat_edata[i*num_edges:(i+1)*num_edges]), assert F.array_equal(F.tensor(feat_edata[i*num_edges:(i+1)*num_edges], dtype=F.float32),
g.edges[etype].data['feat']) g.edges[etype].data['feat'])
assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges], assert np.array_equal(label_edata[i*num_edges:(i+1)*num_edges],
F.asnumpy(g.edges[etype].data['label'])) F.asnumpy(g.edges[etype].data['label']))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment