Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
fa4ecfda
"src/vscode:/vscode.git/clone" did not exist on "d038aa5716a3e1db0ce717eeef469df366b7aade"
Commit
fa4ecfda
authored
Nov 22, 2016
by
Guolin Ke
Browse files
add constructor for booster
parent
de114be5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
111 additions
and
18 deletions
+111
-18
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+111
-18
No files found.
python-package/lightgbm/basic.py
View file @
fa4ecfda
...
@@ -110,6 +110,13 @@ def c_array(ctype, values):
...
@@ -110,6 +110,13 @@ def c_array(ctype, values):
"""Convert a python array to c array."""
"""Convert a python array to c array."""
return
(
ctype
*
len
(
values
))(
*
values
)
return
(
ctype
*
len
(
values
))(
*
values
)
def dict_to_str(data):
    """Serialize a parameter dict into the space-separated "key=value"
    string format expected by the LightGBM C API.

    Parameters
    ----------
    data : dict or None
        Parameter mapping; keys and values are converted with str().

    Returns
    -------
    string
        "k1=v1 k2=v2 ..." in the dict's iteration order, or "" when
        *data* is None or empty.
    """
    # Accept None as well as an empty dict: callers (e.g. Booster.__init__)
    # pass params=None by default, and len(None) would raise TypeError.
    if not data:
        return ""
    return ' '.join(str(key) + '=' + str(data[key]) for key in data)
"""marco definition of data type in c_api of LightGBM"""
"""marco definition of data type in c_api of LightGBM"""
C_API_DTYPE_FLOAT32
=
0
C_API_DTYPE_FLOAT32
=
0
C_API_DTYPE_FLOAT64
=
1
C_API_DTYPE_FLOAT64
=
1
...
@@ -164,7 +171,7 @@ class Dataset(object):
...
@@ -164,7 +171,7 @@ class Dataset(object):
def
__init__
(
self
,
data
,
max_bin
=
255
,
reference
=
None
,
def
__init__
(
self
,
data
,
max_bin
=
255
,
reference
=
None
,
label
=
None
,
weight
=
None
,
group_id
=
None
,
label
=
None
,
weight
=
None
,
group_id
=
None
,
silent
=
False
,
feature_names
=
None
,
silent
=
False
,
feature_names
=
None
,
other_ar
g
s
=
None
):
other_
p
ar
am
s
=
None
,
is_continue_train
=
False
):
"""
"""
Dataset used in LightGBM.
Dataset used in LightGBM.
...
@@ -187,20 +194,27 @@ class Dataset(object):
...
@@ -187,20 +194,27 @@ class Dataset(object):
Whether print messages during construction
Whether print messages during construction
feature_names : list, optional
feature_names : list, optional
Set names for features.
Set names for features.
other_ar
g
s:
lis
t, optional
other_
p
ar
am
s:
dic
t, optional
other parameters
, format: ['key1=val1','key2=val2']
other parameters
"""
"""
if
data
is
None
:
if
data
is
None
:
self
.
handle
=
None
self
.
handle
=
None
return
return
"""save raw data for continue train """
if
is_continue_train
:
self
.
raw_data
=
data
else
:
self
.
raw_data
=
None
"""process for args"""
"""process for args"""
pass_args
=
[
"max_bin={}"
.
format
(
max_bin
)]
params
=
{}
params
[
"max_bin"
]
=
max_bin
if
silent
:
if
silent
:
pass_args
.
append
(
"verbose=0"
)
params
[
"verbose"
]
=
0
if
other_args
:
if
other_params
:
pass_args
+=
other_args
other_params
.
update
(
params
)
pass_args_str
=
' '
.
join
(
pass_args
)
params
=
other_params
params_str
=
dict_to_str
(
params
)
"""process for reference dataset"""
"""process for reference dataset"""
ref_dataset
=
None
ref_dataset
=
None
if
isinstance
(
reference
,
Dataset
):
if
isinstance
(
reference
,
Dataset
):
...
@@ -212,15 +226,15 @@ class Dataset(object):
...
@@ -212,15 +226,15 @@ class Dataset(object):
self
.
handle
=
ctypes
.
c_void_p
()
self
.
handle
=
ctypes
.
c_void_p
()
_safe_call
(
_LIB
.
LGBM_CreateDatasetFromFile
(
_safe_call
(
_LIB
.
LGBM_CreateDatasetFromFile
(
c_str
(
data
),
c_str
(
data
),
c_str
(
pa
ss_arg
s_str
),
c_str
(
pa
ram
s_str
),
ref_dataset
,
ref_dataset
,
ctypes
.
byref
(
self
.
handle
)))
ctypes
.
byref
(
self
.
handle
)))
elif
isinstance
(
data
,
scipy
.
sparse
.
csr_matrix
):
elif
isinstance
(
data
,
scipy
.
sparse
.
csr_matrix
):
self
.
_init_from_csr
(
data
,
pa
ss_arg
s_str
,
ref_dataset
)
self
.
_init_from_csr
(
data
,
pa
ram
s_str
,
ref_dataset
)
elif
isinstance
(
data
,
scipy
.
sparse
.
csc_matrix
):
elif
isinstance
(
data
,
scipy
.
sparse
.
csc_matrix
):
self
.
_init_from_csc
(
data
,
pa
ss_arg
s_str
,
ref_dataset
)
self
.
_init_from_csc
(
data
,
pa
ram
s_str
,
ref_dataset
)
elif
isinstance
(
data
,
np
.
ndarray
):
elif
isinstance
(
data
,
np
.
ndarray
):
self
.
_init_from_npy2d
(
data
,
pa
ss_arg
s_str
,
ref_dataset
)
self
.
_init_from_npy2d
(
data
,
pa
ram
s_str
,
ref_dataset
)
else
:
else
:
try
:
try
:
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
...
@@ -235,7 +249,10 @@ class Dataset(object):
...
@@ -235,7 +249,10 @@ class Dataset(object):
self
.
set_group_id
(
group_id
)
self
.
set_group_id
(
group_id
)
self
.
feature_names
=
feature_names
self
.
feature_names
=
feature_names
def
_init_from_csr
(
self
,
csr
,
pass_args_str
,
ref_dataset
):
def free_raw_data(self):
    """Drop the reference to the cached raw training data so that it can
    be garbage-collected.
    """
    self.raw_data = None
def
_init_from_csr
(
self
,
csr
,
params_str
,
ref_dataset
):
"""
"""
Initialize data from a CSR matrix.
Initialize data from a CSR matrix.
"""
"""
...
@@ -255,11 +272,11 @@ class Dataset(object):
...
@@ -255,11 +272,11 @@ class Dataset(object):
len
(
csr
.
indptr
),
len
(
csr
.
indptr
),
len
(
csr
.
data
),
len
(
csr
.
data
),
csr
.
shape
[
1
],
csr
.
shape
[
1
],
c_str
(
pa
ss_arg
s_str
),
c_str
(
pa
ram
s_str
),
ref_dataset
,
ref_dataset
,
ctypes
.
byref
(
self
.
handle
)))
ctypes
.
byref
(
self
.
handle
)))
def
_init_from_csc
(
self
,
csr
,
pa
ss_arg
s_str
,
ref_dataset
):
def
_init_from_csc
(
self
,
csr
,
pa
ram
s_str
,
ref_dataset
):
"""
"""
Initialize data from a CSC matrix.
Initialize data from a CSC matrix.
"""
"""
...
@@ -279,11 +296,11 @@ class Dataset(object):
...
@@ -279,11 +296,11 @@ class Dataset(object):
len
(
csc
.
indptr
),
len
(
csc
.
indptr
),
len
(
csc
.
data
),
len
(
csc
.
data
),
csc
.
shape
[
0
],
csc
.
shape
[
0
],
c_str
(
pa
ss_arg
s_str
),
c_str
(
pa
ram
s_str
),
ref_dataset
,
ref_dataset
,
ctypes
.
byref
(
self
.
handle
)))
ctypes
.
byref
(
self
.
handle
)))
def
_init_from_npy2d
(
self
,
mat
,
pa
ss_arg
s_str
,
ref_dataset
):
def
_init_from_npy2d
(
self
,
mat
,
pa
ram
s_str
,
ref_dataset
):
"""
"""
Initialize data from a 2-D numpy matrix.
Initialize data from a 2-D numpy matrix.
"""
"""
...
@@ -304,7 +321,7 @@ class Dataset(object):
...
@@ -304,7 +321,7 @@ class Dataset(object):
mat
.
shape
[
0
],
mat
.
shape
[
0
],
mat
.
shape
[
1
],
mat
.
shape
[
1
],
C_API_IS_ROW_MAJOR
,
C_API_IS_ROW_MAJOR
,
c_str
(
pa
ss_arg
s_str
),
c_str
(
pa
ram
s_str
),
ref_dataset
,
ref_dataset
,
ctypes
.
byref
(
self
.
handle
)))
ctypes
.
byref
(
self
.
handle
)))
...
@@ -536,3 +553,79 @@ class Dataset(object):
...
@@ -536,3 +553,79 @@ class Dataset(object):
else
:
else
:
self
.
_feature_names
=
None
self
.
_feature_names
=
None
class Booster(object):
    """A Booster of LightGBM."""

    # Set after construction from the training data's feature names.
    feature_names = None

    def __init__(self, params=None, train_set=None, valid_sets=None,
                 name_valid_sets=None, model_file=None, fobj=None):
        # pylint: disable=invalid-name
        """Initialize the Booster.

        Parameters
        ----------
        params : dict or None
            Parameters for boosters.
        train_set : Dataset or None
            Training dataset.
        valid_sets : list of Dataset or None
            Validation datasets.
        name_valid_sets : list of string or None
            Names of the validation datasets; defaults to "valid_0",
            "valid_1", ...
        model_file : string or None
            Path to the model file.
        fobj : callable or None
            Customized objective function (not used in construction).

        Raises
        ------
        TypeError
            If neither train_set nor model_file is given, or a dataset
            argument is not a Dataset instance.
        Exception
            If len(valid_sets) != len(name_valid_sets).
        """
        self.handle = ctypes.c_void_p()
        if train_set is not None:
            if not isinstance(train_set, Dataset):
                raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__))
            valid_handles = None
            valid_cnames = None
            n_valid = 0
            if valid_sets is not None:
                for valid in valid_sets:
                    if not isinstance(valid, Dataset):
                        raise TypeError('valid data should be Dataset instance, met{}'.format(type(valid).__name__))
                valid_handles = c_array(ctypes.c_void_p,
                                        [valid.handle for valid in valid_sets])
                if name_valid_sets is None:
                    name_valid_sets = ["valid_{}".format(x)
                                       for x in range(len(valid_sets))]
                if len(valid_sets) != len(name_valid_sets):
                    raise Exception('len of valid_sets should be equal with len of name_valid_sets')
                valid_cnames = c_array(ctypes.c_char_p,
                                       [c_str(x) for x in name_valid_sets])
                n_valid = len(valid_sets)
            ref_input_model = None
            # Guard against the default params=None: dict_to_str would
            # raise TypeError on len(None).
            params_str = dict_to_str(params if params is not None else {})
            if model_file is not None:
                ref_input_model = c_str(model_file)
            # Construct the booster object.
            # Fixed: was "LIB.LGBM_BoosterCreate" (NameError; the module
            # uses _LIB everywhere else), and params_str was passed as a
            # Python string while every sibling C-API call wraps strings
            # in c_str(...) -- NOTE(review): confirm against the C API
            # signature.
            _safe_call(_LIB.LGBM_BoosterCreate(
                train_set.handle,
                valid_handles,
                valid_cnames,
                n_valid,
                c_str(params_str),
                ref_input_model,
                ctypes.byref(self.handle)))
            # When continuing training from an existing model, set up the
            # continue-train state for the training and validation sets.
            if model_file is not None:
                self.init_continue_train(train_set)
                if valid_sets is not None:
                    for valid in valid_sets:
                        self.init_continue_train(valid)
        elif model_file is not None:
            # Prediction-only booster loaded from a saved model file.
            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
                c_str(model_file),
                ctypes.byref(self.handle)))
        else:
            raise TypeError('At least need training dataset or model file to create booster instance')

    def __del__(self):
        # Release the underlying C booster handle.
        _LIB.LGBM_BoosterFree(self.handle)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment