Unverified Commit 9f70e968 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] respect parameter aliases for network params (#3813)



* [dask] allow parameter aliases for tree_learner and local_listen_port (fixes #3671)

* num_thread too

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* empty commit

* add _choose_param_value

* revert param order change

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update python-package/lightgbm/dask.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* just import deepcopy

* remove machines aliases

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 5005b7be
# coding: utf-8 # coding: utf-8
"""Wrapper for C API of LightGBM.""" """Wrapper for C API of LightGBM."""
import copy
import ctypes import ctypes
import json import json
import os import os
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy
from functools import wraps from functools import wraps
from logging import Logger from logging import Logger
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Dict
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
...@@ -352,6 +353,46 @@ class _ConfigAliases: ...@@ -352,6 +353,46 @@ class _ConfigAliases:
return ret return ret
def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
    """Get a single parameter value, accounting for aliases.

    Parameters
    ----------
    main_param_name : str
        Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
    params : dict
        Dictionary of LightGBM parameters.
    default_value : Any
        Default value to use for the parameter, if none is found in ``params``.

    Returns
    -------
    params : dict
        A ``params`` dict with exactly one value for ``main_param_name``, and all aliases of ``main_param_name`` removed.
        If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
    """
    # work on a copy so the caller's dict is never mutated
    params = deepcopy(params)

    # Prefer an explicit value under the main name; otherwise take the first
    # alias that carries a non-None value. Every alias spelling (including the
    # main name itself) is popped so exactly one spelling survives below.
    chosen_value = params.get(main_param_name)
    for alias in _ConfigAliases.get(main_param_name):
        alias_value = params.pop(alias, None)
        if chosen_value is None:
            chosen_value = alias_value

    params[main_param_name] = default_value if chosen_value is None else chosen_value
    return params
MAX_INT32 = (1 << 31) - 1 MAX_INT32 = (1 << 31) - 1
"""Macro definition of data type in C API of LightGBM""" """Macro definition of data type in C API of LightGBM"""
...@@ -1052,7 +1093,7 @@ class Dataset: ...@@ -1052,7 +1093,7 @@ class Dataset:
self.silent = silent self.silent = silent
self.feature_name = feature_name self.feature_name = feature_name
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
self.params = copy.deepcopy(params) self.params = deepcopy(params)
self.free_raw_data = free_raw_data self.free_raw_data = free_raw_data
self.used_indices = None self.used_indices = None
self.need_slice = True self.need_slice = True
...@@ -1510,13 +1551,13 @@ class Dataset: ...@@ -1510,13 +1551,13 @@ class Dataset:
def _update_params(self, params): def _update_params(self, params):
if not params: if not params:
return self return self
params = copy.deepcopy(params) params = deepcopy(params)
def update(): def update():
if not self.params: if not self.params:
self.params = params self.params = params
else: else:
self.params_back_up = copy.deepcopy(self.params) self.params_back_up = deepcopy(self.params)
self.params.update(params) self.params.update(params)
if self.handle is None: if self.handle is None:
...@@ -1536,7 +1577,7 @@ class Dataset: ...@@ -1536,7 +1577,7 @@ class Dataset:
def _reverse_update_params(self): def _reverse_update_params(self):
if self.handle is None: if self.handle is None:
self.params = copy.deepcopy(self.params_back_up) self.params = deepcopy(self.params_back_up)
self.params_back_up = None self.params_back_up = None
return self return self
...@@ -2130,7 +2171,7 @@ class Booster: ...@@ -2130,7 +2171,7 @@ class Booster:
self.__set_objective_to_none = False self.__set_objective_to_none = False
self.best_iteration = -1 self.best_iteration = -1
self.best_score = {} self.best_score = {}
params = {} if params is None else copy.deepcopy(params) params = {} if params is None else deepcopy(params)
# user can set verbose with params, it has higher priority # user can set verbose with params, it has higher priority
if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent: if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent:
params["verbose"] = -1 params["verbose"] = -1
...@@ -2139,22 +2180,40 @@ class Booster: ...@@ -2139,22 +2180,40 @@ class Booster:
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}' raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__)) .format(type(train_set).__name__))
# set network if necessary params = _choose_param_value(
for alias in _ConfigAliases.get("machines"): main_param_name="machines",
if alias in params: params=params,
machines = params[alias] default_value=None
if isinstance(machines, str): )
num_machines = len(machines.split(',')) # if "machines" is given, assume user wants to do distributed learning, and set up network
elif isinstance(machines, (list, set)): if params["machines"] is None:
num_machines = len(machines) params.pop("machines", None)
machines = ','.join(machines) else:
else: machines = params["machines"]
raise ValueError("Invalid machines in params.") if isinstance(machines, str):
self.set_network(machines, num_machines_from_machine_list = len(machines.split(','))
local_listen_port=params.get("local_listen_port", 12400), elif isinstance(machines, (list, set)):
listen_time_out=params.get("listen_time_out", 120), num_machines_from_machine_list = len(machines)
num_machines=params.setdefault("num_machines", num_machines)) machines = ','.join(machines)
break else:
raise ValueError("Invalid machines in params.")
params = _choose_param_value(
main_param_name="num_machines",
params=params,
default_value=num_machines_from_machine_list
)
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
self.set_network(
machines=machines,
local_listen_port=params["local_listen_port"],
listen_time_out=params.get("time_out", 120),
num_machines=params["num_machines"]
)
# construct booster object # construct booster object
train_set.construct() train_set.construct()
# copy the parameters from train_set # copy the parameters from train_set
...@@ -3056,7 +3115,7 @@ class Booster: ...@@ -3056,7 +3115,7 @@ class Booster:
Prediction result. Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``). Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
""" """
predictor = self._to_predictor(copy.deepcopy(kwargs)) predictor = self._to_predictor(deepcopy(kwargs))
if num_iteration is None: if num_iteration is None:
if start_iteration <= 0: if start_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
...@@ -3090,14 +3149,14 @@ class Booster: ...@@ -3090,14 +3149,14 @@ class Booster:
""" """
if self.__set_objective_to_none: if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.') raise LightGBMError('Cannot refit due to null objective function.')
predictor = self._to_predictor(copy.deepcopy(kwargs)) predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True) leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_bool(False) out_is_linear = ctypes.c_bool(False)
_safe_call(_LIB.LGBM_BoosterGetLinear( _safe_call(_LIB.LGBM_BoosterGetLinear(
self.handle, self.handle,
ctypes.byref(out_is_linear))) ctypes.byref(out_is_linear)))
new_params = copy.deepcopy(self.params) new_params = deepcopy(self.params)
new_params["linear_tree"] = out_is_linear.value new_params["linear_tree"] = out_is_linear.value
train_set = Dataset(data, label, silent=True, params=new_params) train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate new_params['refit_decay_rate'] = decay_rate
......
...@@ -21,7 +21,7 @@ from dask import dataframe as dd ...@@ -21,7 +21,7 @@ from dask import dataframe as dd
from dask import delayed from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait from dask.distributed import Client, default_client, get_worker, wait
from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
...@@ -196,6 +196,44 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -196,6 +196,44 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
""" """
params = deepcopy(params) params = deepcopy(params)
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
params = _choose_param_value(
main_param_name="tree_learner",
params=params,
default_value="data"
)
allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if params["tree_learner"] not in allowed_tree_learners:
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
params['tree_learner'] = 'data'
if params['tree_learner'] not in {'data', 'data_parallel'}:
_log_warning(
'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n'
'Use "data" for a stable, well-tested interface.' % params['tree_learner']
)
# Some passed-in parameters can be removed:
# * 'machines': constructed automatically from Dask worker list
# * 'machine_list_filename': not relevant for the Dask interface
# * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process
for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']:
for param_alias in _ConfigAliases.get(param_name):
params.pop(param_alias, None)
# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True) data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False) label_parts = _split_to_parts(data=label, is_matrix=False)
...@@ -230,65 +268,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -230,65 +268,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
master_worker = next(iter(worker_map)) master_worker = next(iter(worker_map))
worker_ncores = client.ncores() worker_ncores = client.ncores()
tree_learner = None
for tree_learner_param in _ConfigAliases.get('tree_learner'):
tree_learner = params.get(tree_learner_param)
if tree_learner is not None:
params['tree_learner'] = tree_learner
break
allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if tree_learner is None:
_log_warning('Parameter tree_learner not set. Using "data" as default')
params['tree_learner'] = 'data'
elif tree_learner.lower() not in allowed_tree_learners:
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
params['tree_learner'] = 'data'
if params['tree_learner'] not in {'data', 'data_parallel'}:
_log_warning(
'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' % params['tree_learner']
)
local_listen_port = 12400
for port_param in _ConfigAliases.get('local_listen_port'):
val = params.get(port_param)
if val is not None:
local_listen_port = val
break
# find an open port on each worker. note that multiple workers can run # find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its # on the same machine, so this needs to ensure that each one gets its
# own port # own port
worker_address_to_port = _find_ports_for_workers( worker_address_to_port = _find_ports_for_workers(
client=client, client=client,
worker_addresses=worker_map.keys(), worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port local_listen_port=params["local_listen_port"]
) )
# num_threads is set below, so remove it and all aliases of it from params
for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None)
# machines is constructed manually, so remove it and all aliases of it from params
for machine_alias in _ConfigAliases.get('machines'):
params.pop(machine_alias, None)
# machines is constructed manually, so remove machine_list_filename and all aliases of it from params
for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
params.pop(machine_list_filename_alias, None)
# machines is constructed manually, so remove num_machines and all aliases of it from params
for num_machine_alias in _ConfigAliases.get('num_machines'):
params.pop(num_machine_alias, None)
# Tell each worker to train on the parts that it has locally # Tell each worker to train on the parts that it has locally
futures_classifiers = [ futures_classifiers = [
client.submit( client.submit(
......
...@@ -329,3 +329,49 @@ def test_consistent_state_for_dataset_fields(): ...@@ -329,3 +329,49 @@ def test_consistent_state_for_dataset_fields():
lgb_data.set_init_score(sequence) lgb_data.set_init_score(sequence)
lgb_data.set_feature_name(feature_names) lgb_data.set_feature_name(feature_names)
check_asserts(lgb_data) check_asserts(lgb_data)
def test_choose_param_value():
    original_params = {
        "local_listen_port": 1234,
        "port": 2222,
        "metric": "auc",
        "num_trees": 81
    }

    # duplicate aliases should be collapsed, with the main parameter winning
    params = lgb.basic._choose_param_value(
        main_param_name="local_listen_port",
        params=original_params,
        default_value=5555
    )
    assert params["local_listen_port"] == 1234
    assert "port" not in params

    # when only an alias is present, its value is promoted to the main parameter
    params = lgb.basic._choose_param_value(
        main_param_name="num_iterations",
        params=params,
        default_value=17
    )
    assert params["num_iterations"] == 81
    assert "num_trees" not in params

    # with neither the main parameter nor any alias present, the default is used
    params = lgb.basic._choose_param_value(
        main_param_name="learning_rate",
        params=params,
        default_value=0.789
    )
    assert params["learning_rate"] == 0.789

    # the input dict must never be mutated by any of the calls above
    assert original_params == {
        "local_listen_port": 1234,
        "port": 2222,
        "metric": "auc",
        "num_trees": 81
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment