Unverified Commit 47993776 authored by Daniil Sizov, committed by GitHub

[Feature] Rework Dataloader cpu affinitization as helper method (#4126)



* Add helper method for temporary affinitization of compute threads

* Rework DL affinitization as single helper

* Add example usage in benchmarks

* Fix python linter warnings

* Fix affinity helper params

* Use NUMA node 0 cores only by default

* Fix benchmarks

* Fix lint errors
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent bc2fef9c
......@@ -114,11 +114,14 @@ def track_time(data):
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Enable dataloader cpu affinitization for cpu devices (no effect on gpu)
with dataloader.enable_cpu_affinity():
# Training loop
avg = 0
iter_tput = []
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
# Load the input features as well as output labels
blocks = [block.int().to(device) for block in blocks]
......
......@@ -288,6 +288,8 @@ def track_time(data):
optimizer.zero_grad()
sparse_optimizer.zero_grad()
# Enable dataloader cpu affinitization for cpu devices (no effect on gpu)
with loader.enable_cpu_affinity():
for step, (input_nodes, seeds, blocks) in enumerate(loader):
blocks = [blk.to(device) for blk in blocks]
seeds = seeds[category] # we only predict the nodes with type "category"
......
......@@ -283,6 +283,8 @@ def track_time(data):
model.train()
embed_layer.train()
# Enable dataloader cpu affinitization for cpu devices (no effect on gpu)
with loader.enable_cpu_affinity():
for step, sample_data in enumerate(loader):
input_nodes, output_nodes, blocks = sample_data
feats = embed_layer(input_nodes,
......
......@@ -94,9 +94,12 @@ def track_time(data):
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Enable dataloader cpu affinitization for cpu devices (no effect on gpu)
with dataloader.enable_cpu_affinity():
# Training loop
avg = 0
iter_tput = []
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
# Load the input features as well as output labels
#batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)
......
......@@ -9,6 +9,7 @@ import inspect
import re
import atexit
import os
from contextlib import contextmanager
import psutil
import numpy as np
......@@ -20,8 +21,8 @@ from ..base import NID, EID, dgl_warning, DGLError
from ..batch import batch as batch_graphs
from ..heterograph import DGLHeteroGraph
from ..utils import (
recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
get_num_threads, get_numa_nodes_cores, context_of, dtype_of)
from ..frame import LazyFeature
from ..storages import wrap_storage
from .base import BlockSampler, as_edge_prediction_sampler
......@@ -697,8 +698,7 @@ class DataLoader(torch.utils.data.DataLoader):
def __init__(self, graph, indices, graph_sampler, device=None, use_ddp=False,
ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
use_prefetch_thread=None, use_alternate_streams=None,
pin_prefetcher=None, use_uva=False, **kwargs):
# (BarclayII) PyTorch Lightning sometimes will recreate a DataLoader from an existing
# DataLoader with modifications to the original arguments. The arguments are retrieved
# from the attributes with the same name, and because we change certain arguments
......@@ -840,31 +840,12 @@ class DataLoader(torch.utils.data.DataLoader):
self.use_alternate_streams = use_alternate_streams
self.pin_prefetcher = pin_prefetcher
self.use_prefetch_thread = use_prefetch_thread
self.cpu_affinity_enabled = False
worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))
self.other_storages = {}
super().__init__(
self.dataset,
collate_fn=CollateWrapper(
......@@ -875,6 +856,11 @@ class DataLoader(torch.utils.data.DataLoader):
**kwargs)
def __iter__(self):
if self.device.type == 'cpu' and not self.cpu_affinity_enabled:
link = 'https://docs.dgl.ai/tutorials/cpu/cpu_best_practises.html'
dgl_warning(f'Dataloader CPU affinity opt is not enabled, consider switching it on '
f'(see enable_cpu_affinity() or CPU best practices for DGL [{link}])')
if self.shuffle:
self.dataset.shuffle()
# When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1
......@@ -882,20 +868,89 @@ class DataLoader(torch.utils.data.DataLoader):
num_threads = torch.get_num_threads() if self.num_workers > 0 else None
return _PrefetchingIter(self, super().__iter__(), num_threads=num_threads)
@contextmanager
def enable_cpu_affinity(self, loader_cores=None, compute_cores=None, verbose=True):
    """Helper method for enabling CPU affinity for dataloader workers and compute threads.

    Only for CPU devices.
    Uses only the cores of NUMA node 0 by default on multi-node systems.

    Parameters
    ----------
    loader_cores : [int] (optional)
        List of CPU cores to which dataloader workers should be affinitized.
        default: node0_cores[0:num_workers]
    compute_cores : [int] (optional)
        List of CPU cores to which compute threads should be affinitized.
        default: node0_cores[num_workers:]
    verbose : bool (optional)
        If True, affinity information will be printed to the console.

    Usage
    -----
    with dataloader.enable_cpu_affinity():
        <training loop>
    """
    if self.device.type == 'cpu':
        if not self.num_workers > 0:
            raise Exception('ERROR: affinity should be used with at least one DL worker')
        if loader_cores and len(loader_cores) != self.num_workers:
            raise Exception('ERROR: cpu_affinity incorrect '
                            'number of loader_cores={} for num_workers={}'
                            .format(loader_cores, self.num_workers))

        # False positive E0203 (access-member-before-definition) linter warning
        worker_init_fn_old = self.worker_init_fn  # pylint: disable=E0203
        affinity_old = psutil.Process().cpu_affinity()
        nthreads_old = get_num_threads()

        compute_cores = compute_cores[:] if compute_cores else []
        loader_cores = loader_cores[:] if loader_cores else []

        def init_fn(worker_id):
            try:
                psutil.Process().cpu_affinity([loader_cores[worker_id]])
            except:
                raise Exception('ERROR: cannot use affinity id={} cpu={}'
                                .format(worker_id, loader_cores))

            worker_init_fn_old(worker_id)

        if not loader_cores or not compute_cores:
            numa_info = get_numa_nodes_cores()
            if numa_info and len(numa_info[0]) > self.num_workers:
                # take one thread per each node 0 core
                node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]
            else:
                node0_cores = list(range(psutil.cpu_count(logical=False)))

            if len(node0_cores) <= self.num_workers:
                raise Exception('ERROR: more workers than available cores')

            loader_cores = loader_cores or node0_cores[0:self.num_workers]
            compute_cores = [cpu for cpu in node0_cores if cpu not in loader_cores]

        try:
            psutil.Process().cpu_affinity(compute_cores)
            set_num_threads(len(compute_cores))
            self.worker_init_fn = init_fn
            self.cpu_affinity_enabled = True
            if verbose:
                print('{} DL workers are assigned to cpus {}, main process will use cpus {}'
                      .format(self.num_workers, loader_cores, compute_cores))

            yield
        finally:
            # restore omp_num_threads and cpu affinity
            psutil.Process().cpu_affinity(affinity_old)
            set_num_threads(nthreads_old)
            self.worker_init_fn = worker_init_fn_old
            self.cpu_affinity_enabled = False
    else:
        yield
# To allow data other than node/edge data to be prefetched.
def attach_data(self, name, data):
......
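Since the helper restores the previous affinity mask and OMP thread count in its finally
block, separate phases can opt in independently. A hypothetical sketch (assumes a CPU
device, num_workers > 0, and user-defined run_training_epoch/run_validation functions):

    with train_loader.enable_cpu_affinity(verbose=False):
        run_training_epoch()
    # the previous affinity and thread count are restored here
    with val_loader.enable_cpu_affinity():
        run_validation()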
......@@ -4,6 +4,8 @@ from __future__ import absolute_import, division
from collections.abc import Mapping, Iterable, Sequence
from collections import defaultdict
from functools import wraps
import glob
import os
import numpy as np
from ..base import DGLError, dgl_warning, NID, EID
......@@ -914,6 +916,46 @@ def set_num_threads(num_threads):
"""
_CAPI_DGLSetOMPThreads(num_threads)
def get_num_threads():
"""Get the number of OMP threads in the process"""
return _CAPI_DGLGetOMPThreads()
def get_numa_nodes_cores():
""" Returns numa nodes info, format:
{<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>, ...]), ...], ...}
E.g.: {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]}
If not available, returns {}
"""
numa_node_paths = glob.glob('/sys/devices/system/node/node[0-9]*')
if not numa_node_paths:
return {}
nodes = {}
try:
for node_path in numa_node_paths:
numa_node_id = int(os.path.basename(node_path)[4:])
thread_siblings = {}
for cpu_dir in glob.glob(os.path.join(node_path, 'cpu[0-9]*')):
cpu_id = int(os.path.basename(cpu_dir)[3:])
with open(os.path.join(cpu_dir, 'topology', 'core_id')) as core_id_file:
core_id = int(core_id_file.read().strip())
if core_id in thread_siblings:
thread_siblings[core_id].append(cpu_id)
else:
thread_siblings[core_id] = [cpu_id]
nodes[numa_node_id] = sorted([(k, sorted(v)) for k, v in thread_siblings.items()])
except (OSError, ValueError, IndexError, IOError):
dgl_warning('Failed to read NUMA info')
return {}
return nodes
def alias_func(func):
"""Return an alias function with proper docstring."""
@wraps(func)
......
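For illustration, the sketch below derives the default loader/compute core lists the same
way enable_cpu_affinity() does. This is a hedged example: it assumes num_workers=4, that
psutil is installed, and that the helper is importable from dgl.utils as the dataloader
diff above suggests.

    import psutil
    from dgl.utils import get_numa_nodes_cores  # assumed import path

    num_workers = 4
    numa_info = get_numa_nodes_cores()
    if numa_info and len(numa_info[0]) > num_workers:
        # one hardware thread per physical core of NUMA node 0
        node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]
    else:
        # fallback when sysfs NUMA info is unavailable
        node0_cores = list(range(psutil.cpu_count(logical=False)))

    loader_cores = node0_cores[:num_workers]
    compute_cores = [c for c in node0_cores if c not in loader_cores]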
......@@ -26,6 +26,10 @@ DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
omp_set_num_threads(num_threads);
});
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLGetOMPThreads")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = omp_get_max_threads();
});
DGL_REGISTER_GLOBAL("utils.checks._CAPI_DGLCOOIsSorted")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
......
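A quick sanity check of the new OMP thread C API pair (a hedged sketch; it assumes the
Python bindings are exposed via dgl.utils as set_num_threads/get_num_threads, matching the
imports in the dataloader diff above):

    from dgl.utils import set_num_threads, get_num_threads  # assumed import path

    old = get_num_threads()   # _CAPI_DGLGetOMPThreads -> omp_get_max_threads()
    set_num_threads(4)        # _CAPI_DGLSetOMPThreads -> omp_set_num_threads(4)
    assert get_num_threads() == 4
    set_num_threads(old)      # restore the previous OMP thread count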
......@@ -22,25 +22,39 @@ OpenMP settings
During training on CPU, the training and dataloading parts need to run simultaneously.
The best OpenMP parallelization performance can be achieved by setting up the optimal
number of working threads and dataloading workers.
Machines with a high number of CPU cores may benefit from a higher number of dataloading
workers. A good starting point could be setting num_workers=4 in the Dataloader constructor
for machines with 32 cores or more. If the number of cores is rather small, the best
performance might be achieved with just one dataloader worker, or even with num_workers=0,
so that dataloading and training are performed in the same process.

**Dataloader CPU affinity**

If the number of dataloader workers is greater than 0, please consider using the
**enable_cpu_affinity()** method of the DGL Dataloader class; it will generally result in
significantly better training performance.

*enable_cpu_affinity* will set the proper number of OpenMP threads (equal to the number of
CPU cores allocated to the main process), affinitize dataloader workers to separate CPU
cores, and restrict the main process to the remaining cores.

In setups with multiple NUMA nodes, *enable_cpu_affinity* will by default use only the
cores of NUMA node 0, on the assumption that the workload scales poorly across multiple
NUMA nodes. If you believe your workload will perform better when utilizing more than one
NUMA node, you can pass explicit lists of the cores to use for dataloading (*loader_cores*)
and for compute (*compute_cores*) for finer control over which cores are used.

Depending on the dataset, model, and CPU, the optimal number of dataloader workers and
OpenMP threads may vary, but it will usually be close to the general defaults presented
above [#f4]_ .

Usage:

.. code:: python

    dataloader = dgl.dataloading.DataLoader(...)
    ...
    with dataloader.enable_cpu_affinity():
        <training loop or inferencing>
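To pin cores explicitly, pass the lists yourself. A hypothetical sketch (the core IDs below
assume a machine with two 16-core NUMA nodes and 4 dataloader workers):

.. code:: python

    dataloader = dgl.dataloading.DataLoader(...)
    ...
    # workers on the first 4 cores of node 0, compute threads on all remaining cores
    with dataloader.enable_cpu_affinity(loader_cores=[0, 1, 2, 3],
                                        compute_cores=list(range(4, 32))):
        <training loop or inferencing>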
**Manual control**

For advanced and more fine-grained control over OpenMP settings, please refer to the
Maximize Performance of Intel® Optimization for PyTorch* on CPU article [#f4]_.
.. rubric:: Footnotes
......