# frame.py
"""Columnar storage for DGLGraph."""
from __future__ import absolute_import

from collections import namedtuple
try:
    # ``MutableMapping`` lives in ``collections.abc``; the alias in
    # ``collections`` was removed in Python 3.10.
    from collections.abc import MutableMapping
except ImportError:  # pragma: no cover - Python 2 fallback
    from collections import MutableMapping

import numpy as np

from . import backend as F
from . import utils
from .base import DGLError, dgl_warning


class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
    """The column scheme.

    A scheme is an immutable ``(shape, dtype)`` pair; two schemes compare
    equal when both fields are equal (plain tuple comparison).

    Parameters
    ----------
    shape : tuple of int
        The feature shape.
    dtype : TVMType
        The feature data type.
    """


def infer_scheme(tensor):
    """Infer the column scheme of the given tensor.

    The first dimension of the tensor is dropped from the shape, so the
    resulting scheme describes the feature of a single row.

    Parameters
    ----------
    tensor : Tensor
        The tensor whose scheme is inferred.

    Returns
    -------
    Scheme
        The scheme built from the trailing dimensions and the data type.
    """
    return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))


class Column(object):
    """A column is a compact store of features of multiple nodes/edges.

    Currently, we use one dense tensor to batch all the feature tensors
    together (along the first dimension).

    Parameters
    ----------
    data : Tensor
        The initial data of the column.
    scheme : Scheme, optional
        The scheme of the column. Will be inferred if not provided.
    """
    def __init__(self, data, scheme=None):
        self.data = data
        # Test against None explicitly instead of relying on truthiness.
        self.scheme = infer_scheme(data) if scheme is None else scheme

    def __len__(self):
        """The column length."""
        return F.shape(self.data)[0]

    @property
    def shape(self):
        """Return the feature shape stored in the column scheme."""
        return self.scheme.shape

    def __getitem__(self, idx):
        """Return the feature data given the index.

        Parameters
        ----------
        idx : slice or utils.Index
            The index.

        Returns
        -------
        Tensor
            The feature data
        """
        if isinstance(idx, slice):
            return self.data[idx]
        user_idx = idx.tousertensor(F.context(self.data))
        return F.gather_row(self.data, user_idx)

    def __setitem__(self, idx, feats):
        """Update the feature data given the index.

        The update is performed out-placely so it can be used in autograd mode.
        For inplace write, please use ``update``.

        Parameters
        ----------
        idx : utils.Index or slice
            The index.
        feats : Tensor
            The new features.
        """
        self.update(idx, feats, inplace=False)

    def update(self, idx, feats, inplace):
        """Update the feature data given the index.

        Parameters
        ----------
        idx : utils.Index or slice
            The index.
        feats : Tensor
            The new features.
        inplace : bool
            If true, use inplace write.

        Raises
        ------
        DGLError
            If the scheme of ``feats`` differs from the column scheme.
        """
        feat_scheme = infer_scheme(feats)
        if feat_scheme != self.scheme:
            # The column scheme goes first, matching the message wording.
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))

        if isinstance(idx, utils.Index):
            idx = idx.tousertensor(F.context(self.data))

        if inplace:
            F.scatter_row_inplace(self.data, idx, feats)
        elif isinstance(idx, slice):
            # for contiguous indices pack is usually faster than scatter row
            part1 = F.narrow_row(self.data, 0, idx.start)
            part3 = F.narrow_row(self.data, idx.stop, len(self))
            self.data = F.cat([part1, feats, part3], dim=0)
        else:
            self.data = F.scatter_row(self.data, idx, feats)

    def extend(self, feats, feat_scheme=None):
        """Extend the feature data.

        The new features are copied to the context of the column data and
        concatenated after the existing rows.

        Parameters
        ----------
        feats : Tensor
            The new features.
        feat_scheme : Scheme, optional
            The scheme of ``feats``. Will be inferred if not provided.

        Raises
        ------
        DGLError
            If the scheme of ``feats`` differs from the column scheme.
        """
        if feat_scheme is None:
            # ``infer_scheme`` is a module-level function, not a Scheme method.
            feat_scheme = infer_scheme(feats)

        if feat_scheme != self.scheme:
            raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
                           % (self.scheme, feat_scheme))

        feats = F.copy_to(feats, F.context(self.data))
        self.data = F.cat([self.data, feats], dim=0)

    @staticmethod
    def create(data):
        """Create a new column using the given data.

        If ``data`` is already a Column, a new Column object wrapping the
        same data tensor (and reusing its scheme) is returned; otherwise
        ``data`` is wrapped and its scheme is inferred.
        """
        if isinstance(data, Column):
            return Column(data.data, data.scheme)
        return Column(data)


class Frame(MutableMapping):
    """The columnar storage for node/edge features.

    The frame is a dictionary from feature fields to feature columns.
    All columns should have the same number of rows (i.e. the same first dimension).

    Parameters
    ----------
    data : dict-like, optional
        The frame data in dictionary. If the provided data is another frame,
        this frame will NOT share columns with the given frame. So any out-place
        update on one will not reflect to the other. The inplace update will
        be seen by both. This follows the semantic of python's container.
    """
    def __init__(self, data=None):
        if data is None:
            self._columns = dict()
            self._num_rows = 0
        else:
            # Note that we always create a new column for the given data.
            # This avoids two frames accidentally sharing the same column.
            self._columns = {k: Column.create(v) for k, v in data.items()}
            if self._columns:
                self._num_rows = len(next(iter(self._columns.values())))
            else:
                self._num_rows = 0
            # sanity check
            for name, col in self._columns.items():
                if len(col) != self._num_rows:
                    raise DGLError('Expected all columns to have same # rows (%d), '
                                   'got %d on %r.' % (self._num_rows, len(col), name))
        # Initializer for empty values. Initializer is a callable.
        # If is none, then a warning will be raised
        # in the first call and zero initializer will be used later.
        self._initializer = None

    def _warn_and_set_initializer(self):
        """Warn that no initializer is set and fall back to zero initializer."""
        dgl_warning('Initializer is not set. Use zero initializer instead.'
                    ' To suppress this warning, use `set_initializer` to'
                    ' explicitly specify which initializer to use.')
        self._initializer = lambda shape, dtype: F.zeros(shape, dtype)

    def set_initializer(self, initializer):
        """Set the initializer for empty values.

        Initializer is a callable that returns a tensor given the shape and data type.

        Parameters
        ----------
        initializer : callable
            The initializer.
        """
        self._initializer = initializer

    @property
    def initializer(self):
        """Return the initializer of this frame."""
        return self._initializer

    @property
    def schemes(self):
        """Return a dictionary of column name to column schemes."""
        return {k: col.scheme for k, col in self._columns.items()}

    @property
    def num_columns(self):
        """Return the number of columns in this frame."""
        return len(self._columns)

    @property
    def num_rows(self):
        """Return the number of rows in this frame."""
        return self._num_rows

    def __contains__(self, name):
        """Return true if the given column name exists."""
        return name in self._columns

    def __getitem__(self, name):
        """Return the column of the given name.

        Parameters
        ----------
        name : str
            The column name.

        Returns
        -------
        Column
            The column.
        """
        return self._columns[name]

    def __setitem__(self, name, data):
        """Update the whole column.

        Parameters
        ----------
        name : str
            The column name.
        data : Column or data convertible to Column
            The column data.
        """
        self.update_column(name, data)

    def __delitem__(self, name):
        """Delete the whole column.

        Parameters
        ----------
        name : str
            The column name.
        """
        del self._columns[name]
        if not self._columns:
            # No column left, so the number of rows is unknown again.
            self._num_rows = 0

    def add_column(self, name, scheme, ctx):
        """Add a new column to the frame.

        The frame will be initialized by the initializer.

        Parameters
        ----------
        name : str
            The column name.
        scheme : Scheme
            The column scheme.
        ctx : TVMContext
            The column context.

        Raises
        ------
        DGLError
            If the frame has no rows, so the column length cannot be inferred.
        """
        if name in self:
            dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
            return
        if self.num_rows == 0:
            raise DGLError('Cannot add column "%s" using column schemes because'
                           ' number of rows is unknown. Make sure there is at least'
                           ' one column in the frame so number of rows can be inferred.' % name)
        if self.initializer is None:
            self._warn_and_set_initializer()
        # TODO(minjie): directly init data on the target device.
        init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
        init_data = F.copy_to(init_data, ctx)
        self._columns[name] = Column(init_data, scheme)

    def update_column(self, name, data):
        """Add or replace the column with the given name and data.

        Parameters
        ----------
        name : str
            The column name.
        data : Column or data convertible to Column
            The column data.

        Raises
        ------
        DGLError
            If the frame is not empty and the data has a different number
            of rows than the frame.
        """
        col = Column.create(data)
        if self.num_columns == 0:
            self._num_rows = len(col)
        elif len(col) != self._num_rows:
            raise DGLError('Expected data to have %d rows, got %d.' %
                           (self._num_rows, len(col)))
        self._columns[name] = col

    def append(self, other):
        """Append another frame's data into this frame.

        If the current frame is empty, it will just use the columns of the
        given frame. Otherwise, the given data should contain all the
        column keys of this frame. Appending data without any column is a
        no-op.

        Parameters
        ----------
        other : Frame or dict-like
            The frame data to be appended.

        Raises
        ------
        KeyError
            If the given data has a column that this (non-empty) frame lacks.
        DGLError
            If the given data misses a column of this (non-empty) frame.
        """
        if not isinstance(other, Frame):
            other = Frame(other)
        if not self._columns:
            for key, col in other.items():
                # Wrap in a new Column so the two frames never share a
                # Column object (same policy as the constructor).
                self._columns[key] = Column.create(col)
            self._num_rows = other.num_rows
            return
        if other.num_columns == 0:
            # nothing to append
            return
        # Validate the keys up front so a failure cannot leave the columns
        # with different numbers of rows.
        for key in other:
            if key not in self._columns:
                raise KeyError(key)
        missing = [key for key in self._columns if key not in other]
        if missing:
            raise DGLError('Cannot append data that misses column(s) %s'
                           ' of this frame.' % missing)
        for key, col in other.items():
            self._columns[key].extend(col.data, col.scheme)
        self._num_rows += other.num_rows

    def clear(self):
        """Clear this frame. Remove all the columns."""
        self._columns = {}
        self._num_rows = 0

    def __iter__(self):
        """Return an iterator of columns."""
        return iter(self._columns)

    def __len__(self):
        """Return the number of columns."""
        return self.num_columns

    def keys(self):
        """Return the keys."""
        return self._columns.keys()


class FrameRef(MutableMapping):
    """Reference object to a frame on a subset of rows.

    Parameters
    ----------
    frame : Frame, optional
        The underlying frame. If not given, the reference will point to a
        new empty frame.
    index : iterable, slice, or int, optional
        The rows that are referenced in the underlying frame. If not given,
        the whole frame is referenced. The index should be distinct (no
        duplication is allowed).

        Note that if a slice is given, the step must be None.
    """
    def __init__(self, frame=None, index=None):
        self._frame = frame if frame is not None else Frame()
        if index is None:
            # _index_data can be either a slice or an iterable
            self._index_data = slice(0, self._frame.num_rows)
        else:
            # TODO(minjie): check no duplication
            self._index_data = index
        # Lazily built caches; reset by ``_clear_cache``.
        self._index = None
        self._index_or_slice = None

    @property
    def schemes(self):
        """Return the frame schemes.

        Returns
        -------
        dict of str to Scheme
            The frame schemes.
        """
        return self._frame.schemes

    @property
    def num_columns(self):
        """Return the number of columns in the referred frame."""
        return self._frame.num_columns

    @property
    def num_rows(self):
        """Return the number of rows referred."""
        if isinstance(self._index_data, slice):
            # NOTE: we always assume that slice.step is None
            return self._index_data.stop - self._index_data.start
        return len(self._index_data)

    def set_initializer(self, initializer):
        """Set the initializer for empty values.

        Initializer is a callable that returns a tensor given the shape and data type.

        Parameters
        ----------
        initializer : callable
            The initializer.
        """
        self._frame.set_initializer(initializer)

    def index(self):
        """Return the index object.

        Returns
        -------
        utils.Index
            The index.
        """
        if self._index is None:
            if self.is_contiguous():
                self._index = utils.toindex(
                    F.arange(self._index_data.start,
                             self._index_data.stop))
            else:
                self._index = utils.toindex(self._index_data)
        return self._index

    def index_or_slice(self):
        """Returns the index object or the slice

        Returns
        -------
        utils.Index or slice
            The index or slice
        """
        if self._index_or_slice is None:
            if self.is_contiguous():
                self._index_or_slice = self._index_data
            else:
                self._index_or_slice = utils.toindex(self._index_data)
        return self._index_or_slice

    def __contains__(self, name):
        """Return whether the column name exists."""
        return name in self._frame

    def __iter__(self):
        """Return the iterator of the columns."""
        return iter(self._frame)

    def __len__(self):
        """Return the number of columns."""
        return self.num_columns

    def keys(self):
        """Return the keys."""
        return self._frame.keys()

    def __getitem__(self, key):
        """Get data from the frame.

        If the provided key is string, the corresponding column data will be returned.
        If the provided key is an index or a slice, the corresponding rows will be selected.
        The returned rows are saved in a lazy dictionary so only the real selection happens
        when the explicit column name is provided.

        Examples (using pytorch)
        ------------------------
        >>> # create a frame of two columns and five rows
        >>> f = Frame({'c1' : torch.zeros([5, 2]), 'c2' : torch.ones([5, 2])})
        >>> fr = FrameRef(f)
        >>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
        >>> rows = fr[Index([1, 2])]
        >>> rows['c1']  # only select rows for 'c1' column; 'c2' column is not sliced.

        Parameters
        ----------
        key : str or utils.Index or slice
            The key.

        Returns
        -------
        Tensor or lazy dict or tensors
            Depends on whether it is a column selection or row selection.
        """
        if isinstance(key, str):
            return self.select_column(key)
        return self.select_rows(key)

    def select_column(self, name):
        """Return the column of the given name.

        If only part of the rows are referenced, fetching the whole column
        will also slice out the referenced rows.

        Parameters
        ----------
        name : str
            The column name.

        Returns
        -------
        Tensor
            The column data.
        """
        col = self._frame[name]
        if self.is_span_whole_column():
            return col.data
        return col[self.index_or_slice()]

    def select_rows(self, query):
        """Return the rows given the query.

        Parameters
        ----------
        query : utils.Index or slice
            The rows to be selected.

        Returns
        -------
        utils.LazyDict
            The lazy dictionary from str to the selected data.
        """
        rows = self._getrows(query)
        return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())

    def __setitem__(self, key, val):
        """Update the data in the frame.

        If the provided key is string, the corresponding column data will be updated.
        The provided value should be one tensor that have the same scheme and length
        as the column.

        If the provided key is an index, the corresponding rows will be updated. The
        value provided should be a dictionary of string to the data of each column.

        All updates are performed out-placely to work with autograd. For inplace
        update, use ``update_column`` or ``update_rows``.

        Parameters
        ----------
        key : str or utils.Index
            The key.
        val : Tensor or dict of tensors
            The value.
        """
        if isinstance(key, str):
            self.update_column(key, val, inplace=False)
        else:
            self.update_rows(key, val, inplace=False)

    def update_column(self, name, data, inplace):
        """Update the column.

        If this frameref spans the whole column of the underlying frame, this is
        equivalent to update the column of the frame.

        If this frameref only points to part of the rows, then update the column
        here will correspond to update part of the column in the frame. If the
        column does not exist yet, it is first added to the frame (filled by
        the frame initializer) and then the referenced rows are written.

        Parameters
        ----------
        name : str
            The column name.
        data : Tensor
            The update data.
        inplace : bool
            True if the update is performed inplacely.
        """
        if self.is_span_whole_column():
            col = Column.create(data)
            if self.num_columns == 0:
                # the frame is empty
                self._index_data = slice(0, len(col))
                self._clear_cache()
            self._frame[name] = col
        else:
            if name not in self._frame:
                ctx = F.context(data)
                self._frame.add_column(name, infer_scheme(data), ctx)
            fcol = self._frame[name]
            fcol.update(self.index_or_slice(), data, inplace)

    def add_rows(self, num_rows):
        """Add blank rows.

        For existing fields, the rows will be extended according to their
        initializers.

        Parameters
        ----------
        num_rows : int
            Number of rows to add
        """
        feat_placeholders = {}
        for key in self._frame:
            # Checked per column (as before) so no warning is emitted when
            # the frame has no column at all.
            if self._frame.initializer is None:
                self._frame._warn_and_set_initializer()
            scheme = self._frame[key].scheme
            feat_placeholders[key] = self._frame.initializer(
                (num_rows,) + scheme.shape, scheme.dtype)
        self.append(feat_placeholders)

    def update_rows(self, query, data, inplace):
        """Update the rows.

        If the provided data has new column, it will be added to the frame.

        See Also
        --------
        ``update_column``

        Parameters
        ----------
        query : utils.Index or slice
            The rows to be updated.
        data : dict-like
            The row data.
        inplace : bool
            True if the update is performed inplacely.
        """
        rows = self._getrows(query)
        for key, col in data.items():
            if key not in self:
                # add new column
                tmpref = FrameRef(self._frame, rows)
                tmpref.update_column(key, col, inplace)
            else:
                self._frame[key].update(rows, col, inplace)

    def __delitem__(self, key):
        """Delete data in the frame.

        If the provided key is a string, the corresponding column will be deleted.
        If the provided key is an index object or a slice, the corresponding rows will
        be deleted.

        Please note that "deleted" rows are not really deleted, but simply removed
        in the reference. As a result, if two FrameRefs point to the same Frame, deleting
        from one ref will not reflect on the other. By contrast, deleting columns is real.

        Parameters
        ----------
        key : str or utils.Index
            The key.
        """
        if isinstance(key, str):
            del self._frame[key]
            if len(self._frame) == 0:
                self.clear()
        else:
            self.delete_rows(key)

    def delete_rows(self, query):
        """Delete rows.

        Please note that "deleted" rows are not really deleted, but simply removed
        in the reference. As a result, if two FrameRefs point to the same Frame, deleting
        from one ref will not reflect on the other. By contrast, deleting columns is real.

        Parameters
        ----------
        query : utils.Index or slice
            The rows to be deleted.
        """
        if isinstance(query, slice):
            query = range(query.start, query.stop)
        else:
            query = query.tolist()

        if isinstance(self._index_data, slice):
            self._index_data = range(self._index_data.start, self._index_data.stop)
        # ``query`` holds positions within this reference, not frame row ids.
        self._index_data = list(np.delete(self._index_data, query))
        self._clear_cache()

    def append(self, other):
        """Append another frame into this one.

        The appended rows are always added to the rows referenced by this
        object, whether the current index is a slice or an explicit list.

        Parameters
        ----------
        other : dict of str to tensor
            The data to be appended.
        """
        span_whole = self.is_span_whole_column()
        contiguous = self.is_contiguous()
        old_nrows = self._frame.num_rows
        self._frame.append(other)
        new_nrows = self._frame.num_rows
        # update index
        if span_whole:
            self._index_data = slice(0, new_nrows)
        elif contiguous and self._index_data.stop == old_nrows:
            # The slice ends at the old frame end, so it can simply grow.
            self._index_data = slice(self._index_data.start, new_nrows)
        else:
            if contiguous:
                new_idx = list(range(self._index_data.start, self._index_data.stop))
            else:
                # Explicit index: materialize it so the new rows can be
                # appended, mirroring the contiguous case.
                new_idx = list(self.index().tolist())
            new_idx.extend(range(old_nrows, new_nrows))
            self._index_data = new_idx
        self._clear_cache()

    def clear(self):
        """Clear the frame."""
        self._frame.clear()
        self._index_data = slice(0, 0)
        self._clear_cache()

    def is_contiguous(self):
        """Return whether this refers to a contiguous range of rows."""
        # NOTE: this check could have false negatives
        # NOTE: we always assume that slice.step is None
        return isinstance(self._index_data, slice)

    def is_span_whole_column(self):
        """Return whether this refers to all the rows."""
        return self.is_contiguous() and self.num_rows == self._frame.num_rows

    def _getrows(self, query):
        """Internal function to convert from the local row ids to the row ids of the frame.

        Parameters
        ----------
        query : utils.Index or slice
            Row ids local to this reference.

        Returns
        -------
        utils.Index or slice
            The corresponding row ids of the underlying frame.
        """
        if self.is_contiguous():
            start = self._index_data.start
            if start == 0:
                # shortcut for identical mapping
                return query
            if isinstance(query, slice):
                return slice(query.start + start, query.stop + start)
            return utils.toindex(query.tousertensor() + start)
        idxtensor = self.index().tousertensor()
        if isinstance(query, slice):
            # A slice has no ``tousertensor``; expand it to explicit ids.
            local_ids = F.arange(query.start, query.stop)
        else:
            local_ids = query.tousertensor()
        return utils.toindex(F.gather_row(idxtensor, local_ids))

    def _clear_cache(self):
        """Internal function to clear the cached object."""
        self._index = None
        self._index_or_slice = None


def merge_frames(frames, indices, max_index, reduce_func):
    """Merge a list of frames.

    NOTE: this function is currently disabled and always raises.

    The result frame contains `max_index` number of rows. For each frame in
    the given list, its row is merged as follows:

        merged[indices[i][row]] += frames[i][row]

    Parameters
    ----------
    frames : iterator of dgl.frame.FrameRef
        A list of frames to be merged.
    indices : iterator of dgl.utils.Index
        The indices of the frame rows.
    max_index : int
        The number of rows of the merged frame.
    reduce_func : str
        The reduce function (only 'sum' is supported currently)

    Returns
    -------
    merged : FrameRef
        The merged frame.

    Raises
    ------
    AssertionError
        Always, because the implementation below is known to be buggy.
    """
    # TODO(minjie)
    # Raise explicitly instead of ``assert False``: assert statements are
    # stripped under ``python -O``, which would let the buggy code run.
    raise AssertionError('Buggy code, disabled for now.')
    # The disabled implementation is kept below for reference only.
    assert reduce_func == 'sum'
    assert len(frames) > 0
    schemes = frames[0].schemes
    # create an adj to merge
    # row index is equal to the concatenation of all the indices.
    row = sum([idx.tolist() for idx in indices], [])
    col = list(range(len(row)))
    n = max_index
    m = len(row)
    row = F.unsqueeze(F.tensor(row, dtype=F.int64), 0)
    col = F.unsqueeze(F.tensor(col, dtype=F.int64), 0)
    idx = F.cat([row, col], dim=0)
    dat = F.ones((m,))
    adjmat = F.sparse_tensor(idx, dat, [n, m])
    ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
    merged = {}
    for key in schemes:
        # the rhs of the spmv is the concatenation of all the frame columns
        feats = F.pack([fr[key] for fr in frames])
        merged_feats = F.spmm(ctx_adjmat.get(F.get_context(feats)), feats)
        merged[key] = merged_feats
    merged = FrameRef(Frame(merged))
    return merged