Unverified commit d107872a, authored by James Lamb, committed by GitHub
Browse files

[dask] allow parameter aliases for local_listen_port, num_threads,...


[dask] allow parameter aliases for local_listen_port, num_threads, tree_learner (fixes #3671) (#3789)

* [dask] allow parameter aliases for tree_learner and local_listen_port (fixes #3671)

* num_thread too

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* empty commit
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 4007b34f
......@@ -238,6 +238,9 @@ class _ConfigAliases:
"sparse"},
"label_column": {"label_column",
"label"},
"local_listen_port": {"local_listen_port",
"local_port",
"port"},
"machines": {"machines",
"workers",
"nodes"},
......@@ -255,12 +258,21 @@ class _ConfigAliases:
"num_rounds",
"num_boost_round",
"n_estimators"},
"num_threads": {"num_threads",
"num_thread",
"nthread",
"nthreads",
"n_jobs"},
"objective": {"objective",
"objective_type",
"app",
"application"},
"pre_partition": {"pre_partition",
"is_pre_partition"},
"tree_learner": {"tree_learner",
"tree",
"tree_type",
"tree_learner_type"},
"two_round": {"two_round",
"two_round_loading",
"use_two_round_loading"},
......
......@@ -7,6 +7,7 @@ It is based on dask-xgboost package.
import logging
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Iterable
from urllib.parse import urlparse
......@@ -19,7 +20,7 @@ from dask import dataframe as dd
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait
from .basic import _LIB, _safe_call
from .basic import _ConfigAliases, _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor
logger = logging.getLogger(__name__)
......@@ -170,6 +171,8 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
"""
params = deepcopy(params)
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
......@@ -197,21 +200,47 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
master_worker = next(iter(worker_map))
worker_ncores = client.ncores()
if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
logger.warning('Parameter tree_learner not set or set to incorrect value '
'(%s), using "data" as default', params.get("tree_learner", None))
tree_learner = None
for tree_learner_param in _ConfigAliases.get('tree_learner'):
tree_learner = params.get(tree_learner_param)
if tree_learner is not None:
break
allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if tree_learner is None:
logger.warning('Parameter tree_learner not set. Using "data" as default')
params['tree_learner'] = 'data'
elif tree_learner.lower() not in allowed_tree_learners:
logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
params['tree_learner'] = 'data'
local_listen_port = 12400
for port_param in _ConfigAliases.get('local_listen_port'):
val = params.get(port_param)
if val is not None:
local_listen_port = val
break
# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
local_listen_port = params.get('local_listen_port', 12400)
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
)
# num_threads is set below, so remove it and all aliases of it from params
for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None)
# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
model_factory=model_factory,
......
......@@ -124,7 +124,7 @@ def test_classifier_local_predict(client, listen_port):
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
local_port=listen_port,
n_estimators=10,
num_leaves=10
)
......@@ -148,7 +148,8 @@ def test_regressor(output, client, listen_port):
time_out=5,
local_listen_port=listen_port,
seed=42,
num_leaves=10
num_leaves=10,
tree='data'
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
......@@ -181,7 +182,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
objective='quantile',
alpha=alpha,
n_estimators=10,
num_leaves=10
num_leaves=10,
tree_learner_type='data_parallel'
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
......@@ -210,7 +212,8 @@ def test_regressor_local_predict(client, listen_port):
local_listen_port=listen_port,
seed=42,
n_estimators=10,
num_leaves=10
num_leaves=10,
tree_type='data'
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment