Commit 5b539788 authored by Guolin Ke
Browse files

fix some pep8 check

parent 1a8c23ed
......@@ -4,8 +4,6 @@ from __future__ import absolute_import
import sys
import os
import ctypes
import collections
import re
import tempfile
import numpy as np
......@@ -59,7 +57,7 @@ def is_1d_list(data):
if not isinstance(data, list):
return False
if len(data) > 0:
if not isinstance(data[0], (int, float, bool) ):
if not isinstance(data[0], (int, float, bool)):
return False
return True
......@@ -108,29 +106,29 @@ def param_dict_to_str(data):
if is_str(val):
pairs.append(str(key)+'='+str(val))
elif isinstance(val, (list, tuple)):
pairs.append(str(key)+'='+','.join(map(str,val)))
pairs.append(str(key)+'='+','.join(map(str, val)))
elif isinstance(val, (int, float, bool)):
pairs.append(str(key)+'='+str(val))
else:
raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__))
raise TypeError('unknow type of parameter:%s , got:%s'
% (key, type(val).__name__))
return ' '.join(pairs)
"""marco definition of data type in c_api of LightGBM"""
C_API_DTYPE_FLOAT32 =0
C_API_DTYPE_FLOAT64 =1
C_API_DTYPE_INT32 =2
C_API_DTYPE_INT64 =3
C_API_DTYPE_FLOAT32 = 0
C_API_DTYPE_FLOAT64 = 1
C_API_DTYPE_INT32 = 2
C_API_DTYPE_INT64 = 3
"""Matric is row major in python"""
C_API_IS_ROW_MAJOR =1
C_API_IS_ROW_MAJOR = 1
C_API_PREDICT_NORMAL =0
C_API_PREDICT_RAW_SCORE =1
C_API_PREDICT_LEAF_INDEX =2
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2
FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
"weight":C_API_DTYPE_FLOAT32,
"init_score":C_API_DTYPE_FLOAT32,
"group":C_API_DTYPE_INT32,
}
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT32,
"group": C_API_DTYPE_INT32}
def c_float_array(data):
"""Convert numpy array / list to c float array."""
......@@ -144,7 +142,8 @@ def c_float_array(data):
ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
type_data = C_API_DTYPE_FLOAT64
else:
raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype))
raise TypeError("expected np.float32 or np.float64, met type({})"
.format(data.dtype))
else:
raise TypeError("Unknow type({})".format(type(data).__name__))
return (ptr_data, type_data)
......@@ -161,7 +160,8 @@ def c_int_array(data):
ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
type_data = C_API_DTYPE_INT64
else:
raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype))
raise TypeError("expected np.int32 or np.int64, met type({})"
.format(data.dtype))
else:
raise TypeError("Unknow type({})".format(type(data).__name__))
return (ptr_data, type_data)
......@@ -169,13 +169,13 @@ def c_int_array(data):
class Predictor(object):
""""A Predictor of LightGBM.
"""
def __init__(self,model_file=None, booster_handle=None, is_manage_handle=True):
def __init__(self, model_file=None, booster_handle=None, is_manage_handle=True):
"""Initialize the Predictor.
Parameters
----------
model_file : string
Path to the model file.
Path to the model file.
"""
self.handle = ctypes.c_void_p()
self.__is_manage_handle = True
......@@ -191,7 +191,7 @@ class Predictor(object):
self.handle,
ctypes.byref(out_num_class)))
self.num_class = out_num_class.value
self.__num_total_iteration = out_num_iterations.value
self.__num_total_iteration = out_num_iterations.value
elif booster_handle is not None:
self.__is_manage_handle = is_manage_handle
self.handle = booster_handle
......@@ -204,7 +204,7 @@ class Predictor(object):
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle,
ctypes.byref(out_num_iterations)))
self.__num_total_iteration = out_num_iterations.value
self.__num_total_iteration = out_num_iterations.value
else:
raise TypeError('Need Model file to create a booster')
......@@ -213,7 +213,9 @@ class Predictor(object):
_safe_call(_LIB.LGBM_BoosterFree(self.handle))
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
def predict(self, data, num_iteration=-1,
raw_score=False, pred_leaf=False, data_has_header=False,
is_reshape=True):
"""
Predict logic
......@@ -222,23 +224,24 @@ class Predictor(object):
data : string/numpy array/scipy.sparse
Data source for prediction
When data is string type, it represents the path of txt file,
num_iteration :
num_iteration : int
used iteration for prediction
raw_score : bool
raw_score : bool
True for predict raw score
pred_leaf : bool
True for predict leaf index
data_has_header : bool
Used for txt data
is_reshape : bool
True for reshape to [nrow, ...]
True for reshape to [nrow, ...]
Returns
-------
Prediction result
"""
if isinstance(data, Dataset):
raise TypeError("cannot use Dataset instance for prediction, please use raw data instead")
raise TypeError("cannot use Dataset instance for prediction, \
please use raw data instead")
predict_type = C_API_PREDICT_NORMAL
if raw_score:
predict_type = C_API_PREDICT_RAW_SCORE
......@@ -251,12 +254,12 @@ class Predictor(object):
tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name
_safe_call(_LIB.LGBM_BoosterPredictForFile(
self.handle,
c_str(data),
c_str(data),
int_data_has_header,
predict_type,
num_iteration,
c_str(tmp_pred_fname)))
tmp_file = open(tmp_pred_fname,"r")
tmp_file = open(tmp_pred_fname, "r")
lines = tmp_file.readlines()
tmp_file.close()
nrow = len(lines)
......@@ -267,15 +270,19 @@ class Predictor(object):
preds = np.array(preds, copy=False)
os.remove(tmp_pred_fname)
elif isinstance(data, scipy.sparse.csr_matrix):
preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type)
preds, nrow = self.__pred_for_csr(data, num_iteration,
predict_type)
elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type)
else:
try:
csr = scipy.sparse.csr_matrix(data)
preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
preds, nrow = self.__pred_for_csr(csr, num_iteration,
predict_type)
except:
raise TypeError('can not predict data for type {}'.format(type(data).__name__))
raise TypeError('can not predict data for type {}'.
format(type(data).__name__))
if pred_leaf:
preds = preds.astype(np.int32)
if preds.size != nrow and is_reshape:
......@@ -283,7 +290,8 @@ class Predictor(object):
ncol = int(preds.size / nrow)
preds = preds.reshape(nrow, ncol)
else:
raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' %(preds.size, nrow) )
raise ValueError('len of predict result(%d) cannot be divide nrow (%d)'
% (preds.size, nrow))
return preds
def __get_num_preds(self, num_iteration, nrow, predict_type):
......@@ -308,12 +316,13 @@ class Predictor(object):
"""change non-float data to float data, need to copy"""
data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data = c_float_array(data)
n_preds = self.__get_num_preds(num_iteration, mat.shape[0], predict_type)
n_preds = self.__get_num_preds(num_iteration, mat.shape[0],
predict_type)
preds = np.zeros(n_preds, dtype=np.float32)
out_num_preds = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterPredictForMat(
self.handle,
ptr_data,
ptr_data,
type_ptr_data,
mat.shape[0],
mat.shape[1],
......@@ -341,12 +350,12 @@ class Predictor(object):
_safe_call(_LIB.LGBM_BoosterPredictForCSR(
self.handle,
ptr_indptr,
ptr_indptr,
type_ptr_indptr,
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data,
type_ptr_data,
len(csr.indptr),
type_ptr_data,
len(csr.indptr),
len(csr.data),
csr.shape[1],
predict_type,
......@@ -365,10 +374,10 @@ except ImportError:
class DataFrame(object):
pass
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
'int64': 'int', 'uint8': 'int', 'uint16': 'int',
'uint32': 'int', 'uint64': 'int', 'float16': 'float',
'float32': 'float', 'float64': 'float', 'bool': 'i'}
def _data_from_pandas(data):
if isinstance(data, DataFrame):
......@@ -399,8 +408,8 @@ class Dataset(object):
"""
def __init__(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, predictor=None,
silent=False, params=None):
weight=None, group=None, predictor=None,
silent=False, params=None):
"""
Dataset used in LightGBM.
......@@ -412,7 +421,7 @@ class Dataset(object):
label : list or numpy 1-D array, optional
Label of the data
max_bin : int, required
max number of discrete bin for features
max number of discrete bin for features
reference : Other Dataset, optional
If this dataset validation, need to use training data as reference
weight : list or numpy 1-D array , optional
......@@ -482,10 +491,10 @@ class Dataset(object):
self.set_group(group)
# load init score
if self.predictor is not None and isinstance(self.predictor, Predictor):
init_score = self.predictor.predict(data,
raw_score=True,
data_has_header=self.data_has_header,
is_reshape=False)
init_score = self.predictor.predict(data,
raw_score=True,
data_has_header=self.data_has_header,
is_reshape=False)
if self.predictor.num_class > 1:
# need re group init score
new_init_score = np.zeros(init_score.size(), dtype=np.float32)
......@@ -496,8 +505,8 @@ class Dataset(object):
init_score = new_init_score
self.set_init_score(init_score)
def create_valid(self, data, label=None, weight=None, group=None,
silent=False, params=None):
def create_valid(self, data, label=None, weight=None, group=None,
silent=False, params=None):
"""
Create validation data align with current dataset
......@@ -518,8 +527,8 @@ class Dataset(object):
other parameters
"""
return Dataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group, predictor=self.predictor,
silent=silent, params=params)
weight=weight, group=group, predictor=self.predictor,
silent=silent, params=params)
def subset(self, used_indices, params=None):
"""
......@@ -530,10 +539,10 @@ class Dataset(object):
ret.handle = ctypes.c_void_p()
params_str = param_dict_to_str(params)
_safe_call(_LIB.LGBM_DatasetGetSubset(
ctypes.byref(self.handle),
ctypes.byref(self.handle),
used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
used_indices.shape[0],
c_str(params_str),
c_str(params_str),
ctypes.byref(ret.handle)))
ret.max_bin = self.max_bin
ret.predictor = self.predictor
......@@ -557,13 +566,13 @@ class Dataset(object):
ptr_data, type_ptr_data = c_float_array(data)
_safe_call(_LIB.LGBM_DatasetCreateFromMat(
ptr_data,
ptr_data,
type_ptr_data,
mat.shape[0],
mat.shape[1],
C_API_IS_ROW_MAJOR,
c_str(params_str),
ref_dataset,
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
def __init_from_csr(self, csr, params_str, ref_dataset):
......@@ -578,16 +587,16 @@ class Dataset(object):
ptr_data, type_ptr_data = c_float_array(csr.data)
_safe_call(_LIB.LGBM_DatasetCreateFromCSR(
ptr_indptr,
ptr_indptr,
type_ptr_indptr,
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data,
type_ptr_data,
len(csr.indptr),
type_ptr_data,
len(csr.indptr),
len(csr.data),
csr.shape[1],
c_str(params_str),
ref_dataset,
csr.shape[1],
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
def __del__(self):
......@@ -784,7 +793,7 @@ class Dataset(object):
"""
ret = ctypes.c_int64()
_safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
ctypes.byref(ret)))
ctypes.byref(ret)))
return ret.value
def num_feature(self):
......@@ -796,7 +805,7 @@ class Dataset(object):
"""
ret = ctypes.c_int64()
_safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
ctypes.byref(ret)))
ctypes.byref(ret)))
return ret.value
class Booster(object):
......@@ -812,7 +821,7 @@ class Booster(object):
train_set : Dataset
training dataset
model_file : string
Path to the model file.
Path to the model file.
silent : boolean, optional
Whether print messages during construction
"""
......@@ -833,7 +842,7 @@ class Booster(object):
params_str = param_dict_to_str(params)
"""construct booster object"""
_safe_call(_LIB.LGBM_BoosterCreate(
train_set.handle,
train_set.handle,
c_str(params_str),
ctypes.byref(self.handle)))
"""save reference to data"""
......@@ -859,7 +868,7 @@ class Booster(object):
"""Prediction task"""
out_num_iterations = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file),
c_str(model_file),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))
out_num_class = ctypes.c_int64(0)
......@@ -939,13 +948,13 @@ class Booster(object):
raise Exception("Replace training data failed, you should use same predictor for these data")
self.train_set = train_set
_safe_call(_LIB.LGBM_BoosterResetTrainingData(
self.handle,
self.handle,
self.train_set.handle))
self.__inner_predict_buffer[0] = None
is_finished = ctypes.c_int(0)
if fobj is None:
_safe_call(_LIB.LGBM_BoosterUpdateOneIter(
self.handle,
self.handle,
ctypes.byref(is_finished)))
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return is_finished.value == 1
......@@ -1080,7 +1089,7 @@ class Booster(object):
Parameters
----------
filename : str
filename to save
filename to save
num_iteration: int
number of iteration that want to save. < 0 means save all
"""
......@@ -1098,16 +1107,16 @@ class Booster(object):
data : string/numpy array/scipy.sparse
Data source for prediction
When data is string type, it represents the path of txt file,
num_iteration :
num_iteration : int
used iteration for prediction
raw_score : bool
raw_score : bool
True for predict raw score
pred_leaf : bool
True for predict leaf index
data_has_header : bool
Used for txt data
is_reshape : bool
True for reshape to [nrow, ...]
True for reshape to [nrow, ...]
Returns
-------
......@@ -1136,8 +1145,8 @@ class Booster(object):
result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32)
tmp_out_len = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetEval(
self.handle,
data_idx,
self.handle,
data_idx,
ctypes.byref(tmp_out_len),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))))
if tmp_out_len.value != self.__num_inner_eval:
......@@ -1176,12 +1185,12 @@ class Booster(object):
tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle,
data_idx,
ctypes.byref(tmp_out_len),
self.handle,
data_idx,
ctypes.byref(tmp_out_len),
data_ptr))
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
raise ValueError("incorrect number of predict results for data %d" %(data_idx) )
raise ValueError("incorrect number of predict results for data %d" % (data_idx) )
self.__is_predicted_cur_iter[data_idx] = True
return self.__inner_predict_buffer[data_idx]
......
......@@ -148,7 +148,6 @@ def early_stop(stopping_rounds, verbose=True):
callback : function
The requested callback function.
"""
state = {}
factor_to_bigger_better = {}
best_score = {}
best_iter = {}
......@@ -172,23 +171,21 @@ def early_stop(stopping_rounds, verbose=True):
factor_to_bigger_better[i] = -1.0
if env.evaluation_result_list[i][3]:
factor_to_bigger_better[i] = 1.0
state['best_iter'] = 0
def callback(env):
"""internal function"""
if len(best_score) == 0:
init(env)
for i in range(len(env.evaluation_result_list)):
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
if score > best_score[i]:
best_score[i] = score
best_iter[i] = env.iteration
if verbose:
best_msg[i] = '[%d]\t%s' % ( env.iteration,
best_msg[i] = '[%d]\t%s' % ( env.iteration,
'\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
else:
if env.iteration - best_iter[i] >= stopping_rounds:
state['best_iter'] = best_iter[i]
if env.model is not None:
env.model.set_attr(best_iteration=str(best_iter[i]))
if verbose:
......
"""Training Library containing training routines of LightGBM."""
from __future__ import absolute_import
import collections
import numpy as np
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from . import callback
def _construct_dataset(X_y, reference=None,
params=None, other_fields=None, predictor=None):
params=None, other_fields=None,
predictor=None):
if 'max_bin' in params:
max_bin = int(params['max_bin'])
else:
......@@ -31,20 +31,22 @@ def _construct_dataset(X_y, reference=None,
label = X_y[1]
if reference is None:
ret = Dataset(data, label=label, max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params)
weight=weight, group=group,
predictor=predictor, params=params)
else:
ret = reference.create_valid(data, label=label, weight=weight, group=group, params=params)
ret = reference.create_valid(data, label=label, weight=weight,
group=group, params=params)
if init_score is not None:
ret.set_init_score(init_score)
return ret
def train(params, train_data, num_boost_round=100,
valid_datas=None, valid_names=None,
fobj=None, feval=None, init_model=None,
train_fields=None, valid_fields=None,
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
def train(params, train_data, num_boost_round=100,
valid_datas=None, valid_names=None,
fobj=None, feval=None, init_model=None,
train_fields=None, valid_fields=None,
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
"""Train with given parameters.
Parameters
......@@ -134,9 +136,9 @@ def train(params, train_data, num_boost_round=100,
continue
valid_set = _construct_dataset(
valid_datas[i],
train_set,
params,
other_fields,
train_set,
params,
other_fields,
predictor)
valid_sets.append(valid_set)
if valid_names is not None:
......@@ -182,11 +184,11 @@ def train(params, train_data, num_boost_round=100,
for i in range(num_boost_round):
for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=booster,
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
booster.update(fobj=fobj)
......@@ -199,11 +201,11 @@ def train(params, train_data, num_boost_round=100,
try:
for cb in callbacks_after_iter:
cb(callback.CallbackEnv(model=booster,
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=evaluation_result_list))
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=evaluation_result_list))
except callback.EarlyStopException:
break
if booster.attr('best_iteration') is not None:
......@@ -384,11 +386,11 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
for i in range(num_boost_round):
for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=None,
cvfolds=cvfolds,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
cvfolds=cvfolds,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
for fold in cvfolds:
fold.update(fobj)
res = _agg_cv_result([f.eval(feval) for f in cvfolds])
......@@ -402,13 +404,13 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
try:
for cb in callbacks_after_iter:
cb(callback.CallbackEnv(model=None,
cvfolds=cvfolds,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=res))
cvfolds=cvfolds,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=res))
except callback.EarlyStopException as e:
for k in results.keys():
results[k] = results[k][:(e.state['best_iter'] + 1)]
results[k] = results[k][:(e.best_iteration + 1)]
break
return results
......@@ -12,9 +12,9 @@ def find_lib_path():
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [curr_path, os.path.join(curr_path, '../../lib/'),
os.path.join(curr_path, '../../'),
os.path.join(curr_path, './lib/'),
os.path.join(sys.prefix, 'lightgbm')]
os.path.join(curr_path, '../../'),
os.path.join(curr_path, './lib/'),
os.path.join(sys.prefix, 'lightgbm')]
if os.name == 'nt':
dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/'))
dll_path.append(os.path.join(curr_path, './windows/x64/Dll/'))
......
......@@ -194,7 +194,8 @@ class LGBMModel(LGBMModelBase):
return params
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, train_fields=None, valid_fields=None, other_params=None):
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None, other_params=None):
"""
Fit the gradient boosting model
......@@ -308,7 +309,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
""" + '\n'.join(LGBMModel.__doc__.split('\n')[2:])
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None, other_params=None):
self.classes_ = np.unique(y)
......@@ -328,8 +329,10 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
if eval_set is not None:
eval_set = list( (x[0], self._le.transform(x[1])) for x in eval_set )
super(LGBMClassifier, self).fit(X, training_labels, eval_set, eval_metric,
early_stopping_rounds, verbose, train_fields, valid_fields, other_params)
super(LGBMClassifier, self).fit(X, training_labels, eval_set,
eval_metric, early_stopping_rounds,
verbose, train_fields, valid_fields,
other_params)
return self
def predict(self, data, raw_score=False, num_iteration=0):
......@@ -405,7 +408,7 @@ class LGBMRanker(LGBMModel):
""" + '\n'.join(LGBMModel.__doc__.split('\n')[2:])
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None, other_params=None):
"""check group data"""
......@@ -428,6 +431,8 @@ class LGBMRanker(LGBMModel):
self.objective = "lambdarank"
self.fobj = None
super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
early_stopping_rounds, verbose, train_fields, valid_fields, other_params)
super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
early_stopping_rounds, verbose,
train_fields, valid_fields,
other_params)
return self
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment