Unverified Commit 851d66fa authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix duplicate worker_init_fn argument when provided in DataLoader (#5420)

* fix duplicate worker_init_fn

* lint

* lint again

* uugh
parent 26b245a0
...@@ -1018,7 +1018,7 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -1018,7 +1018,7 @@ class DataLoader(torch.utils.data.DataLoader):
self.use_prefetch_thread = use_prefetch_thread self.use_prefetch_thread = use_prefetch_thread
self.cpu_affinity_enabled = False self.cpu_affinity_enabled = False
worker_init_fn = WorkerInitWrapper(kwargs.get("worker_init_fn", None)) worker_init_fn = WorkerInitWrapper(kwargs.pop("worker_init_fn", None))
self.other_storages = {} self.other_storages = {}
......
...@@ -622,6 +622,26 @@ def test_edge_dataloader_excludes( ...@@ -622,6 +622,26 @@ def test_edge_dataloader_excludes(
break break
def dummy_worker_init_fn(worker_id):
pass
def test_dataloader_worker_init_fn():
dataset = dgl.data.CoraFullDataset()
g = dataset[0]
sampler = dgl.dataloading.MultiLayerNeighborSampler([2])
dataloader = dgl.dataloading.DataLoader(
g,
torch.arange(100),
sampler,
batch_size=4,
num_workers=4,
worker_init_fn=dummy_worker_init_fn,
)
for _ in dataloader:
pass
if __name__ == "__main__": if __name__ == "__main__":
# test_node_dataloader(F.int32, 'neighbor', None) # test_node_dataloader(F.int32, 'neighbor', None)
test_edge_dataloader_excludes( test_edge_dataloader_excludes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment