pytorch_tensor.py 1.92 KB
Newer Older
1
2
3
"""Feature storages for PyTorch tensors."""

import torch
4
5
6
from .base import register_storage_wrapper
from .tensor import BaseTensorStorage
from ..utils import gather_pinned_tensor_rows
7

8
def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs):
9
10
11
12
    result = torch.empty(
        indices.shape[0], *feature_shape, dtype=tensor.dtype,
        pin_memory=pin_memory)
    torch.index_select(tensor, 0, indices, out=result)
13
14
    kwargs['non_blocking'] = pin_memory
    result = result.to(device, **kwargs)
15
16
    return result

17
18
def _fetch_cuda(indices, tensor, device, **kwargs):
    return torch.index_select(tensor, 0, indices).to(device, **kwargs)
19
20

@register_storage_wrapper(torch.Tensor)
class PyTorchTensorStorage(BaseTensorStorage):
    """Feature storage that serves rows sliced out of a PyTorch tensor."""

    def fetch(self, indices, device, pin_memory=False, **kwargs):
        """Return the rows of the stored tensor selected by ``indices``.

        The strategy depends on where the stored tensor and ``indices`` live:

        * storage on CUDA: rows are selected and moved to ``device``;
        * storage not on CUDA, indices not on CUDA: rows are gathered into
          an (optionally pinned) buffer and moved to ``device``;
        * storage not on CUDA but pinned, indices on CUDA: rows are read
          with ``gather_pinned_tensor_rows`` and that result is returned
          as-is (this method does not move it to ``device``);
        * storage not on CUDA and not pinned, indices on CUDA: an error.

        Parameters
        ----------
        indices : torch.Tensor
            Row indices to fetch.
        device : torch.device or str
            Target device; normalized through ``torch.device``.
        pin_memory : bool, optional
            Forwarded to the non-CUDA fetch path, where it controls pinned
            allocation and non-blocking transfer.
        **kwargs
            Extra keyword arguments forwarded to the underlying fetch helper.

        Raises
        ------
        ValueError
            If ``indices`` is on CUDA while the stored tensor is neither on
            CUDA nor pinned.
        """
        device = torch.device(device)

        storage_kind = self.storage.device.type
        indices_kind = indices.device.type

        # CUDA to CUDA or CPU
        if storage_kind == 'cuda':
            return _fetch_cuda(indices, self.storage, device, **kwargs)

        # CPU to CPU or CUDA - use pin_memory and async transfer if possible
        if indices_kind != 'cuda':
            feature_shape = self.storage.shape[1:]
            return _fetch_cpu(
                indices, self.storage, feature_shape, device, pin_memory, **kwargs)

        # From here on the indices are on CUDA while the storage is not, so
        # the storage must be pinned for the rows to be readable.
        if not self.storage.is_pinned():
            raise ValueError(
                f'Got indices on device {indices.device} whereas the feature tensor '
                f'is on {self.storage.device}. Please either (1) move the graph '
                f'to GPU with to() method, or (2) pin the graph with '
                f'pin_memory_() method.')
        return gather_pinned_tensor_rows(self.storage, indices)