Unverified commit b8f905f1, authored by Xin Yao, committed by GitHub

[Bugfix] Fix that pin_prefetcher is not actually enabled (#4169)

parent 10db5d0b
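
For context, the bug meant that `pin_prefetcher` was accepted by DGL's `DataLoader` but never forwarded, so prefetched batches were never placed in page-locked memory. Below is a minimal, DGL-independent sketch of what torch's `pin_memory` flag provides; the dataset and sizes are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# pin_memory=True makes the loader place each batch in page-locked
# (pinned) host memory, which permits asynchronous host-to-device
# copies via non_blocking=True.
dataset = TensorDataset(torch.randn(256, 16))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

for (batch,) in loader:
    # Pinning only takes effect when a CUDA device is available;
    # on CPU-only builds torch merely warns and skips it.
    if torch.cuda.is_available():
        assert batch.is_pinned()
        batch = batch.to("cuda", non_blocking=True)  # async H2D copy
```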
@@ -16,7 +16,7 @@ import torch
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
-from ..base import NID, EID, dgl_warning
+from ..base import NID, EID, dgl_warning, DGLError
 from ..batch import batch as batch_graphs
 from ..heterograph import DGLHeteroGraph
 from ..utils import (
@@ -870,6 +870,7 @@ class DataLoader(torch.utils.data.DataLoader):
             collate_fn=CollateWrapper(
                 self.graph_sampler.sample, graph, self.use_uva, self.device),
             batch_size=None,
+            pin_memory=self.pin_prefetcher,
             worker_init_fn=worker_init_fn,
             **kwargs)
...
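
With this hunk, `pin_prefetcher` is forwarded as `pin_memory` to the wrapped `torch.utils.data.DataLoader`. A hedged usage sketch of the user-facing behavior (graph size, fanouts, and batch size are arbitrary illustration values, and a CUDA device is assumed to be present):

```python
import torch
import dgl

g = dgl.rand_graph(1000, 5000)
sampler = dgl.dataloading.NeighborSampler([10, 10])
loader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler,
    device="cuda",         # assumes a CUDA device is available
    batch_size=64, shuffle=True,
    pin_prefetcher=True)   # before this fix, this flag had no effect
```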
@@ -438,7 +438,7 @@ class Column(TensorStorage):
     def fetch(self, indices, device, pin_memory=False, **kwargs):
         _ = self.data  # materialize in case of lazy slicing & data transfer
-        return super().fetch(indices, device, pin_memory=False, **kwargs)
+        return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)
 
     def pin_memory_(self):
         """Pin the storage into page-locked memory.
...
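
The second hunk is the same bug distilled: an override that hardcodes a keyword argument instead of forwarding it silently disables the feature for every caller. A toy, DGL-independent reduction of the pattern:

```python
class Base:
    def fetch(self, x, pin_memory=False):
        return ("pinned" if pin_memory else "pageable", x)

class Broken(Base):
    def fetch(self, x, pin_memory=False):
        # Hardcoding False drops the caller's choice on the floor.
        return super().fetch(x, pin_memory=False)

class Fixed(Base):
    def fetch(self, x, pin_memory=False):
        # Forward the argument so the caller's setting takes effect.
        return super().fetch(x, pin_memory=pin_memory)

assert Broken().fetch(1, pin_memory=True)[0] == "pageable"  # the bug
assert Fixed().fetch(1, pin_memory=True)[0] == "pinned"     # the fix
```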