Unverified Commit fe1b80a5 authored by jmoralez's avatar jmoralez Committed by GitHub
Browse files

[dask] Include support for raw_score in predict (fixes #3793) (#4024)

* include test for prediction with raw_score

* close client

* initial comments

* update data creation and include ranking task

* linting

* update _create_data

* compare unique raw_predictions with values in leaves_df
parent 8cc6eefc
......@@ -449,7 +449,7 @@ def _predict_part(
# dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
if isinstance(part, pd_DataFrame):
if pred_proba or pred_contrib or pred_leaf:
if len(result.shape) == 2:
result = pd_DataFrame(result, index=part.index)
else:
result = pd_Series(result, index=part.index, name='predictions')
......@@ -510,10 +510,6 @@ def _predict(
**kwargs
).values
elif isinstance(data, dask_Array):
if pred_proba:
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
else:
kwargs['drop_axis'] = 1
return data.map_blocks(
_predict_part,
model=model,
......@@ -522,7 +518,7 @@ def _predict(
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
dtype=dtype,
**kwargs
drop_axis=1
)
else:
raise TypeError('Data must be either Dask Array or Dask DataFrame. Got %s.' % str(type(data)))
......
......@@ -1265,3 +1265,43 @@ def test_parameters_default_constructible(estimator):
else:
Estimator = estimator.__class__
sklearn_checks.check_parameters_default_constructible(name, Estimator)
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_predict_with_raw_score(task, output, client):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
group=None
)
model_factory = task_to_dask_factory[task]
params = {
'client': client,
'n_estimators': 1,
'num_leaves': 2,
'time_out': 5,
'min_sum_hessian': 0
}
model = model_factory(**params)
model.fit(dX, dy, group=dg)
raw_predictions = model.predict(dX, raw_score=True).compute()
trees_df = model.booster_.trees_to_dataframe()
leaves_df = trees_df[trees_df.node_depth == 2]
if task == 'multiclass-classification':
for i in range(model.n_classes_):
class_df = leaves_df[leaves_df.tree_index == i]
assert set(raw_predictions[:, i]) == set(class_df['value'])
else:
assert set(raw_predictions) == set(leaves_df['value'])
if task.endswith('classification'):
pred_proba_raw = model.predict_proba(dX, raw_score=True).compute()
assert_eq(raw_predictions, pred_proba_raw)
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment