Unverified Commit 851d66fa authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix duplicate worker_init_fn argument when provided in DataLoader (#5420)

* fix duplicate worker_init_fn

* lint

* lint again

* uugh
parent 26b245a0
......@@ -1018,7 +1018,7 @@ class DataLoader(torch.utils.data.DataLoader):
self.use_prefetch_thread = use_prefetch_thread
self.cpu_affinity_enabled = False
worker_init_fn = WorkerInitWrapper(kwargs.get("worker_init_fn", None))
worker_init_fn = WorkerInitWrapper(kwargs.pop("worker_init_fn", None))
self.other_storages = {}
......
......@@ -622,6 +622,26 @@ def test_edge_dataloader_excludes(
break
def dummy_worker_init_fn(worker_id):
pass
def test_dataloader_worker_init_fn():
dataset = dgl.data.CoraFullDataset()
g = dataset[0]
sampler = dgl.dataloading.MultiLayerNeighborSampler([2])
dataloader = dgl.dataloading.DataLoader(
g,
torch.arange(100),
sampler,
batch_size=4,
num_workers=4,
worker_init_fn=dummy_worker_init_fn,
)
for _ in dataloader:
pass
if __name__ == "__main__":
# test_node_dataloader(F.int32, 'neighbor', None)
test_edge_dataloader_excludes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment