Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
f059d0fe
Commit
f059d0fe
authored
Nov 26, 2016
by
Guolin Ke
Browse files
fix some bugs in basic.py
parent
6e0b58ba
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
25 deletions
+36
-25
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+36
-25
No files found.
python-package/lightgbm/basic.py
View file @
f059d0fe
...
@@ -78,10 +78,21 @@ def is_numpy_1d_array(data):
...
@@ -78,10 +78,21 @@ def is_numpy_1d_array(data):
else
:
else
:
return
False
return
False
def
is_1d_list
(
data
):
if
not
isinstance
(
data
,
list
):
return
False
if
len
(
data
)
>
0
:
if
not
isinstance
(
data
[
0
],
(
int
,
str
,
bool
)
):
return
False
return
True
def
list_to_1d_numpy
(
data
,
dtype
):
def
list_to_1d_numpy
(
data
,
dtype
):
if
is_numpy_1d_array
(
data
):
if
is_numpy_1d_array
(
data
):
return
data
if
data
.
dtype
==
dtype
:
elif
isinstance
(
data
,
list
):
return
data
else
:
return
data
.
astype
(
dtype
=
dtype
,
copy
=
False
)
elif
is_1d_list
(
data
):
return
np
.
array
(
data
,
dtype
=
dtype
,
copy
=
False
)
return
np
.
array
(
data
,
dtype
=
dtype
,
copy
=
False
)
else
:
else
:
raise
TypeError
(
"Unknow type({})"
.
format
(
type
(
data
).
__name__
))
raise
TypeError
(
"Unknow type({})"
.
format
(
type
(
data
).
__name__
))
...
@@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
...
@@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
def
c_float_array
(
data
):
def
c_float_array
(
data
):
"""Convert numpy array / list to c float array."""
"""Convert numpy array / list to c float array."""
if
is
instance
(
data
,
list
):
if
is
_1d_list
(
data
):
data
=
np
.
array
(
data
,
copy
=
False
)
data
=
np
.
array
(
data
,
copy
=
False
)
if
is_numpy_1d_array
(
data
):
if
is_numpy_1d_array
(
data
):
if
data
.
dtype
==
np
.
float32
:
if
data
.
dtype
==
np
.
float32
:
...
@@ -157,7 +168,7 @@ def c_float_array(data):
...
@@ -157,7 +168,7 @@ def c_float_array(data):
def
c_int_array
(
data
):
def
c_int_array
(
data
):
"""Convert numpy array to c int array."""
"""Convert numpy array to c int array."""
if
is
instance
(
data
,
list
):
if
is
_1d_list
(
data
):
data
=
np
.
array
(
data
,
copy
=
False
)
data
=
np
.
array
(
data
,
copy
=
False
)
if
is_numpy_1d_array
(
data
):
if
is_numpy_1d_array
(
data
):
if
data
.
dtype
==
np
.
int32
:
if
data
.
dtype
==
np
.
int32
:
...
@@ -256,7 +267,7 @@ class Predictor(object):
...
@@ -256,7 +267,7 @@ class Predictor(object):
else
:
else
:
try
:
try
:
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
re
s
=
self
.
__pred_for_csr
(
csr
,
num_iteration
,
predict_type
)
p
re
ds
,
nrow
=
self
.
__pred_for_csr
(
csr
,
num_iteration
,
predict_type
)
except
:
except
:
raise
TypeError
(
'can not predict data for type {}'
.
format
(
type
(
data
).
__name__
))
raise
TypeError
(
'can not predict data for type {}'
.
format
(
type
(
data
).
__name__
))
if
pred_leaf
:
if
pred_leaf
:
...
@@ -417,7 +428,7 @@ class Dataset(object):
...
@@ -417,7 +428,7 @@ class Dataset(object):
else
:
else
:
try
:
try
:
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
csr
=
scipy
.
sparse
.
csr_matrix
(
data
)
self
.
__init_from_csr
(
csr
)
self
.
__init_from_csr
(
csr
,
params_str
,
ref_dataset
)
except
:
except
:
raise
TypeError
(
'can not initialize Dataset from {}'
.
format
(
type
(
data
).
__name__
))
raise
TypeError
(
'can not initialize Dataset from {}'
.
format
(
type
(
data
).
__name__
))
self
.
__label
=
None
self
.
__label
=
None
...
@@ -618,8 +629,6 @@ class Dataset(object):
...
@@ -618,8 +629,6 @@ class Dataset(object):
The label information to be set into Dataset
The label information to be set into Dataset
"""
"""
label
=
list_to_1d_numpy
(
label
,
np
.
float32
)
label
=
list_to_1d_numpy
(
label
,
np
.
float32
)
if
label
.
dtype
!=
np
.
float32
:
label
=
label
.
astype
(
np
.
float32
,
copy
=
False
)
self
.
__label
=
label
self
.
__label
=
label
self
.
set_field
(
'label'
,
label
)
self
.
set_field
(
'label'
,
label
)
...
@@ -633,8 +642,6 @@ class Dataset(object):
...
@@ -633,8 +642,6 @@ class Dataset(object):
"""
"""
if
weight
is
not
None
:
if
weight
is
not
None
:
weight
=
list_to_1d_numpy
(
weight
,
np
.
float32
)
weight
=
list_to_1d_numpy
(
weight
,
np
.
float32
)
if
weight
.
dtype
!=
np
.
float32
:
weight
=
weight
.
astype
(
np
.
float32
,
copy
=
False
)
self
.
__weight
=
weight
self
.
__weight
=
weight
self
.
set_field
(
'weight'
,
weight
)
self
.
set_field
(
'weight'
,
weight
)
...
@@ -647,8 +654,6 @@ class Dataset(object):
...
@@ -647,8 +654,6 @@ class Dataset(object):
"""
"""
if
score
is
not
None
:
if
score
is
not
None
:
score
=
list_to_1d_numpy
(
score
,
np
.
float32
)
score
=
list_to_1d_numpy
(
score
,
np
.
float32
)
if
score
.
dtype
!=
np
.
float32
:
score
=
score
.
astype
(
np
.
float32
,
copy
=
False
)
self
.
__init_score
=
score
self
.
__init_score
=
score
self
.
set_field
(
'init_score'
,
score
)
self
.
set_field
(
'init_score'
,
score
)
...
@@ -662,8 +667,6 @@ class Dataset(object):
...
@@ -662,8 +667,6 @@ class Dataset(object):
"""
"""
if
group
is
not
None
:
if
group
is
not
None
:
group
=
list_to_1d_numpy
(
group
,
np
.
int32
)
group
=
list_to_1d_numpy
(
group
,
np
.
int32
)
if
group
.
dtype
!=
np
.
int32
:
group
=
group
.
astype
(
np
.
int32
,
copy
=
False
)
self
.
__group
=
group
self
.
__group
=
group
self
.
set_field
(
'group'
,
group
)
self
.
set_field
(
'group'
,
group
)
...
@@ -678,8 +681,6 @@ class Dataset(object):
...
@@ -678,8 +681,6 @@ class Dataset(object):
"""
"""
if
group_id
is
not
None
:
if
group_id
is
not
None
:
group_id
=
list_to_1d_numpy
(
group_id
,
np
.
int32
)
group_id
=
list_to_1d_numpy
(
group_id
,
np
.
int32
)
if
group_id
.
dtype
!=
np
.
int32
:
group_id
=
group_id
.
astype
(
np
.
int32
,
copy
=
False
)
self
.
set_field
(
'group_id'
,
group_id
)
self
.
set_field
(
'group_id'
,
group_id
)
def
get_label
(
self
):
def
get_label
(
self
):
...
@@ -890,26 +891,36 @@ class Booster(object):
...
@@ -890,26 +891,36 @@ class Booster(object):
and you should group grad and hess in this way as well
and you should group grad and hess in this way as well
Parameters
Parameters
----------
----------
grad : 1d numpy
with dtype=float32
grad : 1d numpy
or list
The first order of gradient.
The first order of gradient.
hess : 1d numpy
with dtype=float32
hess : 1d numpy
or list
The second order of gradient.
The second order of gradient.
Returns
Returns
-------
-------
is_finished, bool
is_finished, bool
"""
"""
if
not
is_numpy_1d_array
(
grad
)
and
not
is_numpy_1d_array
(
hess
):
if
not
is_numpy_1d_array
(
grad
):
raise
TypeError
(
'type of grad / hess should be 1d numpy object'
)
if
is_1d_list
(
grad
):
if
not
grad
.
dtype
==
np
.
float32
and
not
hess
.
dtype
==
np
.
float32
:
grad
=
np
.
array
(
grad
,
dtype
=
np
.
float32
,
copy
=
False
)
raise
TypeError
(
'type of grad / hess should be np.float32'
)
else
:
raise
TypeError
(
"grad should be numpy 1d array or 1d list"
)
if
not
is_numpy_1d_array
(
hess
):
if
is_1d_list
(
hess
):
hess
=
np
.
array
(
hess
,
dtype
=
np
.
float32
,
copy
=
False
)
else
:
raise
TypeError
(
"hess should be numpy 1d array or 1d list"
)
if
len
(
grad
)
!=
len
(
hess
):
if
len
(
grad
)
!=
len
(
hess
):
raise
ValueError
(
'grad / hess length mismatch: {} / {}'
.
format
(
len
(
grad
),
len
(
hess
)))
raise
ValueError
(
'grad / hess length mismatch: {} / {}'
.
format
(
len
(
grad
),
len
(
hess
)))
if
grad
.
dtype
!=
np
.
float32
:
grad
=
grad
.
astype
(
np
.
float32
,
copy
=
False
)
if
hess
.
dtype
!=
np
.
float32
:
hess
=
hess
.
astype
(
np
.
float32
,
copy
=
False
)
is_finished
=
ctypes
.
c_int
(
0
)
is_finished
=
ctypes
.
c_int
(
0
)
_safe_call
(
_LIB
.
LGBM_BoosterUpdateOneIterCustom
(
_safe_call
(
_LIB
.
LGBM_BoosterUpdateOneIterCustom
(
self
.
handle
,
self
.
handle
,
grad
.
ctypes
.
data_as
(
ctypes
.
ctypes
.
POINTER
(
ctypes
.
c_float
)),
grad
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_float
)),
hess
.
ctypes
.
data_as
(
ctypes
.
ctypes
.
POINTER
(
ctypes
.
c_float
)),
hess
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_float
)),
ctypes
.
byref
(
is_finished
)))
ctypes
.
byref
(
is_finished
)))
return
is_finished
.
value
==
1
return
is_finished
.
value
==
1
...
@@ -950,7 +961,7 @@ class Booster(object):
...
@@ -950,7 +961,7 @@ class Booster(object):
break
break
"""need push new valid data"""
"""need push new valid data"""
if
data_idx
==
-
1
:
if
data_idx
==
-
1
:
self
.
add_valid
_data
(
data
,
name
)
self
.
add_valid
(
data
,
name
)
data_idx
=
self
.
__num_dataset
-
1
data_idx
=
self
.
__num_dataset
-
1
return
self
.
__inner_eval
(
name
,
data_idx
,
feval
)
return
self
.
__inner_eval
(
name
,
data_idx
,
feval
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment