Commit 6891983c authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

add boosting_type to python sklearn (#121)

* add python doc

* add boosting_type to python sklearn

* fix error
parent 54963bb7
...@@ -532,12 +532,15 @@ ...@@ -532,12 +532,15 @@
##Scikit-learn API ##Scikit-learn API
---- ----
###Common Methods ###Common Methods
####__init__(num_leaves=31, max_depth=-1, learning_rate=0.1, n_estimators=10, max_bin=255, silent=True, objective=regression, nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10, subsample=1, subsample_freq=1, colsample_bytree=1, reg_alpha=0, reg_lambda=0, scale_pos_weight=1, is_unbalance=False, seed=0) ####__init__(boosting_type="gbdt", num_leaves=31, max_depth=-1, learning_rate=0.1, n_estimators=10, max_bin=255, silent=True, objective=regression, nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10, subsample=1, subsample_freq=1, colsample_bytree=1, reg_alpha=0, reg_lambda=0, scale_pos_weight=1, is_unbalance=False, seed=0)
Implementation of the Scikit-Learn API for LightGBM. Implementation of the Scikit-Learn API for LightGBM.
Parameters Parameters
---------- ----------
boosting_type : string
gbdt, traditional Gradient Boosting Decision Tree
dart, Dropouts meet Multiple Additive Regression Trees
num_leaves : int num_leaves : int
Maximum tree leaves for base learners. Maximum tree leaves for base learners.
max_depth : int max_depth : int
......
...@@ -132,7 +132,7 @@ def reset_learning_rate(learning_rates): ...@@ -132,7 +132,7 @@ def reset_learning_rate(learning_rates):
env.model.reset_parameter({"learning_rate": \ env.model.reset_parameter({"learning_rate": \
learning_rates(env.iteration - env.begin_iteration, env.end_iteration - env.begin_iteration)}) learning_rates(env.iteration - env.begin_iteration, env.end_iteration - env.begin_iteration)})
else: else:
raise ValueError("Self-defined function 'learning_rates' should have 1 or 2 arguments") raise ValueError("Self-defined function 'learning_rates' should have 1 or 2 arguments, got %d" %(argc))
callback.before_iteration = True callback.before_iteration = True
callback.order = 10 callback.order = 10
return callback return callback
......
# coding: utf-8 # coding: utf-8
# pylint: disable = invalid-name, W0105, C0111 # pylint: disable = invalid-name, W0105, C0111, C0301
"""Scikit-Learn Wrapper interface for LightGBM.""" """Scikit-Learn Wrapper interface for LightGBM."""
from __future__ import absolute_import from __future__ import absolute_import
import inspect import inspect
...@@ -61,7 +61,7 @@ def _objective_function_wrapper(func): ...@@ -61,7 +61,7 @@ def _objective_function_wrapper(func):
elif argc == 3: elif argc == 3:
grad, hess = func(labels, preds, dataset.get_group()) grad, hess = func(labels, preds, dataset.get_group())
else: else:
raise TypeError("parameter number of objective function should be (2, 3), got %d" %(argc)) raise TypeError("Self-defined objective function should have 2 or 3 arguments, got %d" %(argc))
"""weighted for objective""" """weighted for objective"""
weight = dataset.get_weight() weight = dataset.get_weight()
if weight is not None: if weight is not None:
...@@ -89,8 +89,11 @@ def _eval_function_wrapper(func): ...@@ -89,8 +89,11 @@ def _eval_function_wrapper(func):
Parameters Parameters
---------- ----------
func: callable func: callable
Expects a callable with following functions: ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)`` Expects a callable with following functions:
or ``func(y_true, y_pred, weight, group)`` and return (eval_name->str, eval_result->float, is_bigger_better->Bool): ``func(y_true, y_pred)``,
``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``
and return (eval_name->str, eval_result->float, is_bigger_better->Bool):
y_true: array_like of shape [n_samples] y_true: array_like of shape [n_samples]
The target values The target values
...@@ -124,12 +127,12 @@ def _eval_function_wrapper(func): ...@@ -124,12 +127,12 @@ def _eval_function_wrapper(func):
elif argc == 4: elif argc == 4:
return func(labels, preds, dataset.get_weight(), dataset.get_group()) return func(labels, preds, dataset.get_weight(), dataset.get_group())
else: else:
raise TypeError("parameter number of eval function should be (2, 3, 4), got %d" %(argc)) raise TypeError("Self-defined eval function should have 2, 3 or 4 arguments, got %d" %(argc))
return inner return inner
class LGBMModel(LGBMModelBase): class LGBMModel(LGBMModelBase):
def __init__(self, num_leaves=31, max_depth=-1, def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255, learning_rate=0.1, n_estimators=10, max_bin=255,
silent=True, objective="regression", silent=True, objective="regression",
nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10, nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
...@@ -141,6 +144,9 @@ class LGBMModel(LGBMModelBase): ...@@ -141,6 +144,9 @@ class LGBMModel(LGBMModelBase):
Parameters Parameters
---------- ----------
boosting_type : string
gbdt, traditional Gradient Boosting Decision Tree
dart, Dropouts meet Multiple Additive Regression Trees
num_leaves : int num_leaves : int
Maximum tree leaves for base learners. Maximum tree leaves for base learners.
max_depth : int max_depth : int
...@@ -205,6 +211,7 @@ class LGBMModel(LGBMModelBase): ...@@ -205,6 +211,7 @@ class LGBMModel(LGBMModelBase):
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for this module') raise LightGBMError('Scikit-learn is required for this module')
self.boosting_type = boosting_type
self.num_leaves = num_leaves self.num_leaves = num_leaves
self.max_depth = max_depth self.max_depth = max_depth
self.learning_rate = learning_rate self.learning_rate = learning_rate
...@@ -476,14 +483,14 @@ class LGBMRegressor(LGBMModel, LGBMRegressorBase): ...@@ -476,14 +483,14 @@ class LGBMRegressor(LGBMModel, LGBMRegressorBase):
class LGBMClassifier(LGBMModel, LGBMClassifierBase): class LGBMClassifier(LGBMModel, LGBMClassifierBase):
def __init__(self, num_leaves=31, max_depth=-1, def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255, learning_rate=0.1, n_estimators=10, max_bin=255,
silent=True, objective="binary", silent=True, objective="binary",
nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10, nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
subsample=1, subsample_freq=1, colsample_bytree=1, subsample=1, subsample_freq=1, colsample_bytree=1,
reg_alpha=0, reg_lambda=0, scale_pos_weight=1, reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
is_unbalance=False, seed=0): is_unbalance=False, seed=0):
super(LGBMClassifier, self).__init__(num_leaves, max_depth, super(LGBMClassifier, self).__init__(boosting_type, num_leaves, max_depth,
learning_rate, n_estimators, max_bin, learning_rate, n_estimators, max_bin,
silent, objective, nthread, silent, objective, nthread,
min_split_gain, min_child_weight, min_child_samples, min_split_gain, min_child_weight, min_child_samples,
...@@ -561,14 +568,14 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase): ...@@ -561,14 +568,14 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
class LGBMRanker(LGBMModel): class LGBMRanker(LGBMModel):
def __init__(self, num_leaves=31, max_depth=-1, def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255, learning_rate=0.1, n_estimators=10, max_bin=255,
silent=True, objective="lambdarank", silent=True, objective="lambdarank",
nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10, nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
subsample=1, subsample_freq=1, colsample_bytree=1, subsample=1, subsample_freq=1, colsample_bytree=1,
reg_alpha=0, reg_lambda=0, scale_pos_weight=1, reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
is_unbalance=False, seed=0): is_unbalance=False, seed=0):
super(LGBMRanker, self).__init__(num_leaves, max_depth, super(LGBMRanker, self).__init__(boosting_type, num_leaves, max_depth,
learning_rate, n_estimators, max_bin, learning_rate, n_estimators, max_bin,
silent, objective, nthread, silent, objective, nthread,
min_split_gain, min_child_weight, min_child_samples, min_split_gain, min_child_weight, min_child_samples,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment