"python-package/vscode:/vscode.git/clone" did not exist on "d92d844447292e1b2ce8d57e5747b4e0fa233c09"
Commit f893fbf6 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

simplify Dataset class (#163)

* simplify Dataset class

* simplify check output; fix deprecated warning
parent 21ee5947
This diff is collapsed.
...@@ -5,7 +5,7 @@ from __future__ import absolute_import ...@@ -5,7 +5,7 @@ from __future__ import absolute_import
import inspect import inspect
import numpy as np import numpy as np
from .basic import LightGBMError, Dataset from .basic import LightGBMError, Dataset, IS_PY3
from .engine import train from .engine import train
'''sklearn''' '''sklearn'''
try: try:
...@@ -26,6 +26,13 @@ except ImportError: ...@@ -26,6 +26,13 @@ except ImportError:
LGBMLabelEncoder = None LGBMLabelEncoder = None
def _argc(func):
if IS_PY3:
return len(inspect.signature(func).parameters)
else:
return len(inspect.getargspec(func).args)
def _objective_function_wrapper(func): def _objective_function_wrapper(func):
"""Decorate an objective function """Decorate an objective function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
...@@ -57,7 +64,7 @@ def _objective_function_wrapper(func): ...@@ -57,7 +64,7 @@ def _objective_function_wrapper(func):
def inner(preds, dataset): def inner(preds, dataset):
"""internal function""" """internal function"""
labels = dataset.get_label() labels = dataset.get_label()
argc = len(inspect.getargspec(func).args) argc = _argc(func)
if argc == 2: if argc == 2:
grad, hess = func(labels, preds) grad, hess = func(labels, preds)
elif argc == 3: elif argc == 3:
...@@ -122,7 +129,7 @@ def _eval_function_wrapper(func): ...@@ -122,7 +129,7 @@ def _eval_function_wrapper(func):
def inner(preds, dataset): def inner(preds, dataset):
"""internal function""" """internal function"""
labels = dataset.get_label() labels = dataset.get_label()
argc = len(inspect.getargspec(func).args) argc = _argc(func)
if argc == 2: if argc == 2:
return func(labels, preds) return func(labels, preds)
elif argc == 3: elif argc == 3:
......
...@@ -189,11 +189,12 @@ def test_booster(): ...@@ -189,11 +189,12 @@ def test_booster():
LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster)) LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
LIB.LGBM_BoosterAddValidData(booster, test) LIB.LGBM_BoosterAddValidData(booster, test)
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
for i in range(100): for i in range(1, 101):
LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished)) LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
result = np.array([0.0], dtype=np.float64) result = np.array([0.0], dtype=np.float64)
out_len = ctypes.c_ulong(0) out_len = ctypes.c_ulong(0)
LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
if i % 10 == 0:
print('%d Iteration test AUC %f' % (i, result[0])) print('%d Iteration test AUC %f' % (i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment