Unverified Commit 3c7e7e0b authored by Frank Fineis's avatar Frank Fineis Committed by GitHub
Browse files

[python-package] [dask] Add DaskLGBMRanker (#3708)



* ranker support wip

* fix ranker tests

* fix _make_ranking rnd gen bug, add sleep to help w stoch binding port failed exceptions

* add wait_for_workers to prevent Binding port exception

* another attempt to stabilize test_dask.py

* requested changes: docstrings, dask_ml, tuples for list_of_parts

* fix lint bug, add group param to test_ranker_local_predict

* decorator to skip tests with errors on fixture teardown

* remove gpu ranker tests, reduce make_ranking data complexity

* another attempt to
silence client, decorator does not silence fixture errors

* address requested changes on 1/20/20

* skip test_dask for all GPU tasks

* address changes requested on 1/21/21

* issubclass instead of __qualname__
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* parity in group docstr with sklearn
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* _make_ranking docstr cleanup
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 6dbe736e
......@@ -21,7 +21,7 @@ from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait
from .basic import _ConfigAliases, _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
logger = logging.getLogger(__name__)
......@@ -133,15 +133,24 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
}
params.update(network_params)
is_ranker = issubclass(model_factory, LGBMRanker)
# Concatenate many parts into one
parts = tuple(zip(*list_of_parts))
data = _concat(parts[0])
label = _concat(parts[1])
weight = _concat(parts[2]) if len(parts) == 3 else None
try:
model = model_factory(**params)
model.fit(data, label, sample_weight=weight, **kwargs)
if is_ranker:
group = _concat(parts[-1])
weight = _concat(parts[2]) if len(parts) == 4 else None
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
else:
weight = _concat(parts[2]) if len(parts) == 3 else None
model.fit(data, y=label, sample_weight=weight, **kwargs)
finally:
_safe_call(_LIB.LGBM_NetworkFree())
......@@ -156,7 +165,7 @@ def _split_to_parts(data, is_matrix):
return parts
def _train(client, data, label, params, model_factory, weight=None, **kwargs):
def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
"""Inner train routine.
Parameters
......@@ -167,22 +176,36 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
y : dask array of shape = [n_samples]
The target values (class labels in classification, real numbers in regression).
params : dict
model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
Weights of training data.
group : array-like or None, optional (default=None)
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
"""
params = deepcopy(params)
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
if weight is None:
parts = list(map(delayed, zip(data_parts, label_parts)))
weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None
# choose between four options of (sample_weight, group) being (un)specified
if weight_parts is None and group_parts is None:
parts = zip(data_parts, label_parts)
elif weight_parts is not None and group_parts is None:
parts = zip(data_parts, label_parts, weight_parts)
elif weight_parts is None and group_parts is not None:
parts = zip(data_parts, label_parts, group_parts)
else:
weight_parts = _split_to_parts(weight, is_matrix=False)
parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
parts = zip(data_parts, label_parts, weight_parts, group_parts)
# Start computation in the background
parts = list(map(delayed, parts))
parts = client.compute(parts)
wait(parts)
......@@ -281,13 +304,13 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
Parameters
----------
model :
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
proba : bool
Should method return results of predict_proba (proba == True) or predict (proba == False)
Should method return results of predict_proba (proba == True) or predict (proba == False).
dtype : np.dtype
Dtype of the output
Dtype of the output.
kwargs : other parameters passed to predict or predict_proba method
"""
if isinstance(data, dd._Frame):
......@@ -304,13 +327,14 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
class _LGBMModel:
def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
if client is None:
client = default_client()
params = self.get_params(True)
model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)
model = _train(client, data=X, label=y, params=params, model_factory=model_factory,
sample_weight=sample_weight, group=group, **kwargs)
self.set_params(**model.get_params())
self._copy_extra_params(model, self)
......@@ -335,8 +359,8 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
"""Distributed version of lightgbm.LGBMClassifier."""
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(LGBMClassifier, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
fit.__doc__ = LGBMClassifier.fit.__doc__
def predict(self, X, **kwargs):
......@@ -364,7 +388,7 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
return self._fit(LGBMRegressor, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
fit.__doc__ = LGBMRegressor.fit.__doc__
def predict(self, X, **kwargs):
......@@ -380,3 +404,29 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
model : lightgbm.LGBMRegressor
"""
return self._to_local(LGBMRegressor)
class DaskLGBMRanker(_LGBMModel, LGBMRanker):
"""Docstring is inherited from the lightgbm.LGBMRanker."""
def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)
fit.__doc__ = LGBMRanker.fit.__doc__
def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs)
predict.__doc__ = LGBMRanker.predict.__doc__
def to_local(self):
"""Create regular version of lightgbm.LGBMRanker from the distributed version.
Returns
-------
model : lightgbm.LGBMRanker
"""
return self._to_local(LGBMRanker)
# coding: utf-8
"""Tests for lightgbm.dask module"""
import itertools
import os
import socket
import sys
import pytest
if not sys.platform.startswith("linux"):
pytest.skip("lightgbm.dask is currently supported in Linux environments", allow_module_level=True)
if not sys.platform.startswith('linux'):
pytest.skip('lightgbm.dask is currently supported in Linux environments', allow_module_level=True)
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import scipy.sparse
from dask.array.utils import assert_eq
from dask_ml.metrics import accuracy_score, r2_score
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from sklearn.datasets import make_blobs, make_regression
from sklearn.utils import check_random_state
import lightgbm
import lightgbm.dask as dlgbm
data_output = ['array', 'scipy_csr_matrix', 'dataframe']
data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]
group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
pytestmark = [
pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface")
pytest.mark.skipif(os.getenv('TASK', '') == 'mpi', reason='Fails to run with MPI interface'),
pytest.mark.skipif(os.getenv('TASK', '') == 'gpu', reason='Fails to run with GPU interface')
]
......@@ -37,6 +44,135 @@ def listen_port():
listen_port.port = 13000
def _make_ranking(n_samples=100, n_features=20, n_informative=5, gmax=2,
group=None, random_gs=False, avg_gs=10, random_state=0):
"""Generate a learning-to-rank dataset - feature vectors grouped together with
integer-valued graded relevance scores. Replace this with a sklearn.datasets function
if ranking objective becomes supported in sklearn.datasets module.
Parameters
----------
n_samples : int, optional (default=100)
Total number of documents (records) in the dataset.
n_features : int, optional (default=20)
Total number of features in the dataset.
n_informative : int, optional (default=5)
Number of features that are "informative" for ranking, as they are bias + beta * y
where bias and beta are standard normal variates. If this is greater than n_features, the dataset will have
n_features features, all will be informative.
group : array-like, optional (default=None)
1-d array or list of group sizes. When `group` is specified, this overrides n_samples, random_gs, and
avg_gs by simply creating groups with sizes group[0], ..., group[-1].
gmax : int, optional (default=2)
Maximum graded relevance value for creating relevance/target vector. If you set this to 2, for example, all
documents in a group will have relevance scores of either 0, 1, or 2.
random_gs : bool, optional (default=False)
True will make group sizes ~ Poisson(avg_gs), False will make group sizes == avg_gs.
avg_gs : int, optional (default=10)
Average number of documents (records) in each group.
Returns
-------
X : 2-d np.ndarray of shape = [n_samples (or np.sum(group)), n_features]
Input feature matrix for ranking objective.
y : 1-d np.array of shape = [n_samples (or np.sum(group))]
Integer-graded relevance scores.
group_ids : 1-d np.array of shape = [n_samples (or np.sum(group))]
Array of group ids, each value indicates to which group each record belongs.
"""
rnd_generator = check_random_state(random_state)
y_vec, group_id_vec = np.empty((0,), dtype=int), np.empty((0,), dtype=int)
gid = 0
# build target, group ID vectors.
relvalues = range(gmax + 1)
# build y/target and group-id vectors with user-specified group sizes.
if group is not None and hasattr(group, '__len__'):
n_samples = np.sum(group)
for i, gsize in enumerate(group):
y_vec = np.concatenate((y_vec, rnd_generator.choice(relvalues, size=gsize, replace=True)))
group_id_vec = np.concatenate((group_id_vec, [i] * gsize))
# build y/target and group-id vectors according to n_samples, avg_gs, and random_gs.
else:
while len(y_vec) < n_samples:
gsize = avg_gs if not random_gs else rnd_generator.poisson(avg_gs)
# groups should contain > 1 element for pairwise learning objective.
if gsize < 1:
continue
y_vec = np.append(y_vec, rnd_generator.choice(relvalues, size=gsize, replace=True))
group_id_vec = np.append(group_id_vec, [gid] * gsize)
gid += 1
y_vec, group_id_vec = y_vec[:n_samples], group_id_vec[:n_samples]
# build feature data, X. Transform first few into informative features.
n_informative = max(min(n_features, n_informative), 0)
X = rnd_generator.uniform(size=(n_samples, n_features))
for j in range(n_informative):
bias, coef = rnd_generator.normal(size=2)
X[:, j] = bias + coef * y_vec
return X, y_vec, group_id_vec
def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs):
X, y, g = _make_ranking(n_samples=n_samples, random_state=42, **kwargs)
rnd = np.random.RandomState(42)
w = rnd.rand(X.shape[0]) * 0.01
g_rle = np.array([len(list(grp)) for _, grp in itertools.groupby(g)])
if output == 'dataframe':
# add target, weight, and group to DataFrame so that partitions abide by group boundaries.
X_df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
X = X_df.copy()
X_df = X_df.assign(y=y, g=g, w=w)
# set_index ensures partitions are based on group id.
# See https://stackoverflow.com/questions/49532824/dask-dataframe-split-partitions-based-on-a-column-or-function.
X_df.set_index('g', inplace=True)
dX = dd.from_pandas(X_df, chunksize=chunk_size)
# separate target, weight from features.
dy = dX['y']
dw = dX['w']
dX = dX.drop(columns=['y', 'w'])
dg = dX.index.to_series()
# encode group identifiers into run-length encoding, the format LightGBMRanker is expecting
# so that within each partition, sum(g) = n_samples.
dg = dg.map_partitions(lambda p: p.groupby('g', sort=False).apply(lambda z: z.shape[0]))
elif output == 'array':
# ranking arrays: one chunk per group. Each chunk must include all columns.
p = X.shape[1]
dX, dy, dw, dg = [], [], [], []
for g_idx, rhs in enumerate(np.cumsum(g_rle)):
lhs = rhs - g_rle[g_idx]
dX.append(da.from_array(X[lhs:rhs, :], chunks=(rhs - lhs, p)))
dy.append(da.from_array(y[lhs:rhs]))
dw.append(da.from_array(w[lhs:rhs]))
dg.append(da.from_array(np.array([g_rle[g_idx]])))
dX = da.concatenate(dX, axis=0)
dy = da.concatenate(dy, axis=0)
dw = da.concatenate(dw, axis=0)
dg = da.concatenate(dg, axis=0)
else:
raise ValueError('Ranking data creation only supported for Dask arrays and dataframes')
return X, y, w, g_rle, dX, dy, dw, dg
def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50):
if objective == 'classification':
X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
......@@ -96,6 +232,8 @@ def test_classifier(output, centers, client, listen_port):
assert_eq(y, p2)
assert_eq(p1_proba, p2_proba, atol=0.3)
client.close()
def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')
......@@ -118,6 +256,8 @@ def test_training_does_not_fail_on_port_conflicts(client):
)
assert dask_classifier.booster_
client.close()
def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array')
......@@ -139,6 +279,8 @@ def test_classifier_local_predict(client, listen_port):
assert_eq(y, p1)
assert_eq(y, p2)
client.close()
@pytest.mark.parametrize('output', data_output)
def test_regressor(output, client, listen_port):
......@@ -170,6 +312,8 @@ def test_regressor(output, client, listen_port):
assert_eq(y, p1, rtol=1., atol=100.)
assert_eq(y, p2, rtol=1., atol=50.)
client.close()
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9])
......@@ -204,6 +348,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
np.testing.assert_allclose(q1, alpha, atol=0.2)
np.testing.assert_allclose(q2, alpha, atol=0.2)
client.close()
def test_regressor_local_predict(client, listen_port):
X, y, _, dX, dy, dw = _create_data('regression', output='array')
......@@ -226,6 +372,54 @@ def test_regressor_local_predict(client, listen_port):
assert_eq(p1, p2)
assert_eq(s1, s2)
client.close()
@pytest.mark.parametrize('output', ['array', 'dataframe'])
@pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
# use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
# serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel',
n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute()
local_ranker = lightgbm.LGBMRanker(n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
local_ranker.fit(X, y, sample_weight=w, group=g)
rnkvec_local = local_ranker.predict(X)
# distributed ranker should be able to rank decently well and should
# have high rank correlation with scores from serial ranker.
dcor = spearmanr(rnkvec_dask, y).correlation
assert dcor > 0.6
assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.9
client.close()
@pytest.mark.parametrize('output', ['array', 'dataframe'])
@pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker_local_predict(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner='data',
n_estimators=10, num_leaves=10, seed=42, min_child_samples=1)
dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute()
rnkvec_local = dask_ranker.to_local().predict(X)
# distributed and to-local scores should be the same.
assert_eq(rnkvec_dask, rnkvec_local)
client.close()
def test_find_open_port_works():
worker_ip = '127.0.0.1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment