Unverified Commit d2a22984 authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[bugfix] Implement `__setstate__` for Column (fixes #4107) (#4174)



* Work around a graph data saving/loading compatibility problem in the Column class. There may be more places in DGL with the same issue, because they rely on Python serialization instead of a more cohesive, comprehensive strategy. This is just a local fix.

* Add checking for non-empty states

* Add unit test

* Handle the case of columns without storage
Co-authored-by: ndickson <ndickson@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 8b19c287
...@@ -182,14 +182,9 @@ class Column(TensorStorage): ...@@ -182,14 +182,9 @@ class Column(TensorStorage):
index : Tensor index : Tensor
Index tensor Index tensor
""" """
def __init__(self, storage, scheme=None, index=None, device=None, deferred_dtype=None): def __init__(self, storage, *args, **kwargs):
super().__init__(storage) super().__init__(storage)
self.scheme = scheme if scheme else infer_scheme(storage) self._init(*args, **kwargs)
self.index = index
self.device = device
self.deferred_dtype = deferred_dtype
self.pinned_by_dgl = False
self._data_nd = None
def __len__(self): def __len__(self):
"""The number of features (number of rows) in this column.""" """The number of features (number of rows) in this column."""
...@@ -243,6 +238,8 @@ class Column(TensorStorage): ...@@ -243,6 +238,8 @@ class Column(TensorStorage):
def data(self, val): def data(self, val):
"""Update the column data.""" """Update the column data."""
self.index = None self.index = None
self.device = None
self.deferred_dtype = None
self.storage = val self.storage = val
self._data_nd = None # should unpin data if it was pinned. self._data_nd = None # should unpin data if it was pinned.
self.pinned_by_dgl = False self.pinned_by_dgl = False
...@@ -430,8 +427,43 @@ class Column(TensorStorage): ...@@ -430,8 +427,43 @@ class Column(TensorStorage):
def __getstate__(self): def __getstate__(self):
if self.storage is not None: if self.storage is not None:
_ = self.data # evaluate feature slicing # flush any deferred operations
return self.__dict__ _ = self.data
state = self.__dict__.copy()
# data pinning does not get serialized, so we need to remove that from
# the state
state['_data_nd'] = None
state['pinned_by_dgl'] = False
return state
def __setstate__(self, state):
    """Restore the column from an unpickled state dictionary.

    Every key is treated as optional; a missing key is handled the same
    way as a key whose value is ``None`` (or ``False`` for
    ``pinned_by_dgl``).

    Parameters
    ----------
    state : dict
        The attribute dictionary to restore. It becomes this object's
        ``__dict__``.
    """
    if state.get('storage') is not None:
        # A column that carries storage must not also carry a pending
        # index or device.
        assert state.get('index') is None
        assert state.get('device') is None
        index = device = None
    else:
        # we may have a column with only index information, and that is
        # valid
        index = state.get('index')
        device = state.get('device')

    # None of these may be set in a serialized state.
    assert state.get('deferred_dtype') is None
    assert state.get('pinned_by_dgl', False) is False
    assert state.get('_data_nd') is None

    self.__dict__ = state
    # Re-run the shared initializer so that every attribute exists, even
    # those absent from ``state``.
    self._init(getattr(self, 'scheme', None), index=index, device=device)
def _init(self, scheme=None, index=None, device=None, deferred_dtype=None):
    """Set the column's attributes other than ``storage``.

    ``self.storage`` must already be set: it is read to infer the scheme
    when none is supplied.

    Parameters
    ----------
    scheme : optional
        Scheme to use. When falsy, it is inferred from ``self.storage``
        via ``infer_scheme``.
    index : optional
        Value stored as ``self.index``.
    device : optional
        Value stored as ``self.device``.
    deferred_dtype : optional
        Value stored as ``self.deferred_dtype``.
    """
    if not scheme:
        scheme = infer_scheme(self.storage)
    self.scheme = scheme
    self.index, self.device = index, device
    self.deferred_dtype = deferred_dtype
    # The pinning state always starts cleared.
    self.pinned_by_dgl = False
    self._data_nd = None
def __copy__(self): def __copy__(self):
return self.clone() return self.clone()
......
...@@ -4,6 +4,7 @@ from dgl.frame import Column ...@@ -4,6 +4,7 @@ from dgl.frame import Column
import numpy as np import numpy as np
import backend as F import backend as F
import unittest import unittest
import pickle
from test_utils import parametrize_idtype from test_utils import parametrize_idtype
def test_column_subcolumn(): def test_column_subcolumn():
...@@ -14,7 +15,7 @@ def test_column_subcolumn(): ...@@ -14,7 +15,7 @@ def test_column_subcolumn():
[0., 2., 4., 0.]]), F.ctx()) [0., 2., 4., 0.]]), F.ctx())
original = Column(data) original = Column(data)
# subcolumn from cpu context # subcolumn from cpu context
i1 = F.tensor([0, 2, 1, 3], dtype=F.int64) i1 = F.tensor([0, 2, 1, 3], dtype=F.int64)
l1 = original.subcolumn(i1) l1 = original.subcolumn(i1)
...@@ -37,3 +38,47 @@ def test_column_subcolumn(): ...@@ -37,3 +38,47 @@ def test_column_subcolumn():
i1i2i3 = F.copy_to(F.gather_row(i1i2, F.copy_to(i3, F.context(i1i2))), F.ctx()) i1i2i3 = F.copy_to(F.gather_row(i1i2, F.copy_to(i3, F.context(i1i2))), F.ctx())
assert F.array_equal(l3.data, F.gather_row(data, i1i2i3)) assert F.array_equal(l3.data, F.gather_row(data, i1i2i3))
def test_serialize_deserialize_plain():
    """A plain column keeps its data across a pickle round trip."""
    rows = [[1., 1., 1., 1.],
            [0., 2., 9., 0.],
            [3., 2., 1., 0.],
            [1., 1., 1., 1.],
            [0., 2., 4., 0.]]
    original = Column(F.copy_to(F.tensor(rows), F.ctx()))

    new = pickle.loads(pickle.dumps(original))

    print("new = {}".format(new))
    assert F.array_equal(new.data, original.data)
def test_serialize_deserialize_subcolumn():
    """A subcolumn keeps its (sliced) data across a pickle round trip."""
    rows = [[1., 1., 1., 1.],
            [0., 2., 9., 0.],
            [3., 2., 1., 0.],
            [1., 1., 1., 1.],
            [0., 2., 4., 0.]]
    original = Column(F.copy_to(F.tensor(rows), F.ctx()))

    # subcolumn from cpu context
    row_ids = F.tensor([0, 2, 1, 3], dtype=F.int64)
    l1 = original.subcolumn(row_ids)

    new = pickle.loads(pickle.dumps(l1))

    assert F.array_equal(new.data, l1.data)
def test_serialize_deserialize_dtype():
    """A column converted with ``astype`` keeps its dtype after pickling."""
    rows = [[1., 1., 1., 1.],
            [0., 2., 9., 0.],
            [3., 2., 1., 0.],
            [1., 1., 1., 1.],
            [0., 2., 4., 0.]]
    converted = Column(F.copy_to(F.tensor(rows), F.ctx())).astype(F.int64)

    new = pickle.loads(pickle.dumps(converted))

    assert new.dtype == F.int64
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment