"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e3ddbe25edeadaa5afc3f8f5bb0d645098a8b26a"
Unverified Commit c94d2a5f authored by Andrzej Kotłowski's avatar Andrzej Kotłowski Committed by GitHub
Browse files

[Performance] Do not fuse neighbor sampler for 1 thread (#6421)

parent ad4df9c5
...@@ -3,6 +3,7 @@ from .. import backend as F ...@@ -3,6 +3,7 @@ from .. import backend as F
from ..base import EID, NID from ..base import EID, NID
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from ..transforms import to_block from ..transforms import to_block
from ..utils import get_num_threads
from .base import BlockSampler from .base import BlockSampler
...@@ -150,8 +151,9 @@ class NeighborSampler(BlockSampler): ...@@ -150,8 +151,9 @@ class NeighborSampler(BlockSampler):
def sample_blocks(self, g, seed_nodes, exclude_eids=None): def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes output_nodes = seed_nodes
blocks = [] blocks = []
# sample_neighbors_fused function requires multithreading to be more efficient
if self.fused: # than sample_neighbors
if self.fused and get_num_threads() > 1:
cpu = F.device_type(g.device) == "cpu" cpu = F.device_type(g.device) == "cpu"
if isinstance(seed_nodes, dict): if isinstance(seed_nodes, dict):
for ntype in list(seed_nodes.keys()): for ntype in list(seed_nodes.keys()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment