Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
619c06d8
Unverified
Commit
619c06d8
authored
Nov 24, 2017
by
Guolin Ke
Committed by
GitHub
Nov 24, 2017
Browse files
fix save_model_to_string for large model (#1080)
* use int64 for string * [R] Fatal when CSC exceed int32.max
parent
53739670
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
24 deletions
+34
-24
R-package/R/lgb.Dataset.R
R-package/R/lgb.Dataset.R
+3
-1
R-package/R/lgb.Predictor.R
R-package/R/lgb.Predictor.R
+3
-1
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+4
-4
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+6
-6
src/c_api.cpp
src/c_api.cpp
+6
-6
src/lightgbm_R.cpp
src/lightgbm_R.cpp
+12
-6
No files found.
R-package/R/lgb.Dataset.R
View file @
619c06d8
...
@@ -193,7 +193,9 @@ Dataset <- R6Class(
...
@@ -193,7 +193,9 @@ Dataset <- R6Class(
ref_handle
)
ref_handle
)
}
else
if
(
is
(
private
$
raw_data
,
"dgCMatrix"
))
{
}
else
if
(
is
(
private
$
raw_data
,
"dgCMatrix"
))
{
if
(
length
(
private
$
raw_data
@
p
)
>
2147483647
)
{
stop
(
"Cannot support large CSC matrix"
)
}
# Are we using a dgCMatrix (sparsed matrix column compressed)
# Are we using a dgCMatrix (sparsed matrix column compressed)
handle
<-
lgb.call
(
"LGBM_DatasetCreateFromCSC_R"
,
handle
<-
lgb.call
(
"LGBM_DatasetCreateFromCSC_R"
,
ret
=
handle
,
ret
=
handle
,
...
...
R-package/R/lgb.Predictor.R
View file @
619c06d8
...
@@ -127,7 +127,9 @@ Predictor <- R6Class(
...
@@ -127,7 +127,9 @@ Predictor <- R6Class(
private
$
params
)
private
$
params
)
}
else
if
(
is
(
data
,
"dgCMatrix"
))
{
}
else
if
(
is
(
data
,
"dgCMatrix"
))
{
if
(
length
(
data
@
p
)
>
2147483647
)
{
stop
(
"Cannot support large CSC matrix"
)
}
# Check if data is a dgCMatrix (sparse matrix, column compressed format)
# Check if data is a dgCMatrix (sparse matrix, column compressed format)
preds
<-
lgb.call
(
"LGBM_BoosterPredictForCSC_R"
,
preds
<-
lgb.call
(
"LGBM_BoosterPredictForCSC_R"
,
ret
=
preds
,
ret
=
preds
,
...
...
include/LightGBM/c_api.h
View file @
619c06d8
...
@@ -677,8 +677,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
...
@@ -677,8 +677,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
*/
*/
LIGHTGBM_C_EXPORT
int
LGBM_BoosterSaveModelToString
(
BoosterHandle
handle
,
LIGHTGBM_C_EXPORT
int
LGBM_BoosterSaveModelToString
(
BoosterHandle
handle
,
int
num_iteration
,
int
num_iteration
,
int
buffer_len
,
int
64_t
buffer_len
,
int
*
out_len
,
int
64_t
*
out_len
,
char
*
out_str
);
char
*
out_str
);
/*!
/*!
...
@@ -692,8 +692,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
...
@@ -692,8 +692,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
*/
*/
LIGHTGBM_C_EXPORT
int
LGBM_BoosterDumpModel
(
BoosterHandle
handle
,
LIGHTGBM_C_EXPORT
int
LGBM_BoosterDumpModel
(
BoosterHandle
handle
,
int
num_iteration
,
int
num_iteration
,
int
buffer_len
,
int
64_t
buffer_len
,
int
*
out_len
,
int
64_t
*
out_len
,
char
*
out_str
);
char
*
out_str
);
/*!
/*!
...
...
python-package/lightgbm/basic.py
View file @
619c06d8
...
@@ -1668,13 +1668,13 @@ class Booster(object):
...
@@ -1668,13 +1668,13 @@ class Booster(object):
if
num_iteration
<=
0
:
if
num_iteration
<=
0
:
num_iteration
=
self
.
best_iteration
num_iteration
=
self
.
best_iteration
buffer_len
=
1
<<
20
buffer_len
=
1
<<
20
tmp_out_len
=
ctypes
.
c_int
(
0
)
tmp_out_len
=
ctypes
.
c_int
64
(
0
)
string_buffer
=
ctypes
.
create_string_buffer
(
buffer_len
)
string_buffer
=
ctypes
.
create_string_buffer
(
buffer_len
)
ptr_string_buffer
=
ctypes
.
c_char_p
(
*
[
ctypes
.
addressof
(
string_buffer
)])
ptr_string_buffer
=
ctypes
.
c_char_p
(
*
[
ctypes
.
addressof
(
string_buffer
)])
_safe_call
(
_LIB
.
LGBM_BoosterSaveModelToString
(
_safe_call
(
_LIB
.
LGBM_BoosterSaveModelToString
(
self
.
handle
,
self
.
handle
,
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
buffer_len
),
ctypes
.
c_int
64
(
buffer_len
),
ctypes
.
byref
(
tmp_out_len
),
ctypes
.
byref
(
tmp_out_len
),
ptr_string_buffer
))
ptr_string_buffer
))
actual_len
=
tmp_out_len
.
value
actual_len
=
tmp_out_len
.
value
...
@@ -1685,7 +1685,7 @@ class Booster(object):
...
@@ -1685,7 +1685,7 @@ class Booster(object):
_safe_call
(
_LIB
.
LGBM_BoosterSaveModelToString
(
_safe_call
(
_LIB
.
LGBM_BoosterSaveModelToString
(
self
.
handle
,
self
.
handle
,
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
actual_len
),
ctypes
.
c_int
64
(
actual_len
),
ctypes
.
byref
(
tmp_out_len
),
ctypes
.
byref
(
tmp_out_len
),
ptr_string_buffer
))
ptr_string_buffer
))
return
string_buffer
.
value
.
decode
()
return
string_buffer
.
value
.
decode
()
...
@@ -1707,13 +1707,13 @@ class Booster(object):
...
@@ -1707,13 +1707,13 @@ class Booster(object):
if
num_iteration
<=
0
:
if
num_iteration
<=
0
:
num_iteration
=
self
.
best_iteration
num_iteration
=
self
.
best_iteration
buffer_len
=
1
<<
20
buffer_len
=
1
<<
20
tmp_out_len
=
ctypes
.
c_int
(
0
)
tmp_out_len
=
ctypes
.
c_int
64
(
0
)
string_buffer
=
ctypes
.
create_string_buffer
(
buffer_len
)
string_buffer
=
ctypes
.
create_string_buffer
(
buffer_len
)
ptr_string_buffer
=
ctypes
.
c_char_p
(
*
[
ctypes
.
addressof
(
string_buffer
)])
ptr_string_buffer
=
ctypes
.
c_char_p
(
*
[
ctypes
.
addressof
(
string_buffer
)])
_safe_call
(
_LIB
.
LGBM_BoosterDumpModel
(
_safe_call
(
_LIB
.
LGBM_BoosterDumpModel
(
self
.
handle
,
self
.
handle
,
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
buffer_len
),
ctypes
.
c_int
64
(
buffer_len
),
ctypes
.
byref
(
tmp_out_len
),
ctypes
.
byref
(
tmp_out_len
),
ptr_string_buffer
))
ptr_string_buffer
))
actual_len
=
tmp_out_len
.
value
actual_len
=
tmp_out_len
.
value
...
@@ -1724,7 +1724,7 @@ class Booster(object):
...
@@ -1724,7 +1724,7 @@ class Booster(object):
_safe_call
(
_LIB
.
LGBM_BoosterDumpModel
(
_safe_call
(
_LIB
.
LGBM_BoosterDumpModel
(
self
.
handle
,
self
.
handle
,
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
num_iteration
),
ctypes
.
c_int
(
actual_len
),
ctypes
.
c_int
64
(
actual_len
),
ctypes
.
byref
(
tmp_out_len
),
ctypes
.
byref
(
tmp_out_len
),
ptr_string_buffer
))
ptr_string_buffer
))
return
json
.
loads
(
string_buffer
.
value
.
decode
())
return
json
.
loads
(
string_buffer
.
value
.
decode
())
...
...
src/c_api.cpp
View file @
619c06d8
...
@@ -1140,13 +1140,13 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
...
@@ -1140,13 +1140,13 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
#pragma warning(disable : 4996)
#pragma warning(disable : 4996)
int
LGBM_BoosterSaveModelToString
(
BoosterHandle
handle
,
int
LGBM_BoosterSaveModelToString
(
BoosterHandle
handle
,
int
num_iteration
,
int
num_iteration
,
int
buffer_len
,
int
64_t
buffer_len
,
int
*
out_len
,
int
64_t
*
out_len
,
char
*
out_str
)
{
char
*
out_str
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
std
::
string
model
=
ref_booster
->
SaveModelToString
(
num_iteration
);
std
::
string
model
=
ref_booster
->
SaveModelToString
(
num_iteration
);
*
out_len
=
static_cast
<
int
>
(
model
.
size
())
+
1
;
*
out_len
=
static_cast
<
int
64_t
>
(
model
.
size
())
+
1
;
if
(
*
out_len
<=
buffer_len
)
{
if
(
*
out_len
<=
buffer_len
)
{
std
::
strcpy
(
out_str
,
model
.
c_str
());
std
::
strcpy
(
out_str
,
model
.
c_str
());
}
}
...
@@ -1156,13 +1156,13 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
...
@@ -1156,13 +1156,13 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
#pragma warning(disable : 4996)
#pragma warning(disable : 4996)
int
LGBM_BoosterDumpModel
(
BoosterHandle
handle
,
int
LGBM_BoosterDumpModel
(
BoosterHandle
handle
,
int
num_iteration
,
int
num_iteration
,
int
buffer_len
,
int
64_t
buffer_len
,
int
*
out_len
,
int
64_t
*
out_len
,
char
*
out_str
)
{
char
*
out_str
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
std
::
string
model
=
ref_booster
->
DumpModel
(
num_iteration
);
std
::
string
model
=
ref_booster
->
DumpModel
(
num_iteration
);
*
out_len
=
static_cast
<
int
>
(
model
.
size
())
+
1
;
*
out_len
=
static_cast
<
int
64_t
>
(
model
.
size
())
+
1
;
if
(
*
out_len
<=
buffer_len
)
{
if
(
*
out_len
<=
buffer_len
)
{
std
::
strcpy
(
out_str
,
model
.
c_str
());
std
::
strcpy
(
out_str
,
model
.
c_str
());
}
}
...
...
src/lightgbm_R.cpp
View file @
619c06d8
...
@@ -601,14 +601,17 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
...
@@ -601,14 +601,17 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE
out_str
,
LGBM_SE
out_str
,
LGBM_SE
call_state
)
{
LGBM_SE
call_state
)
{
R_API_BEGIN
();
R_API_BEGIN
();
int
out_len
=
0
;
int
64_t
out_len
=
0
;
std
::
vector
<
char
>
inner_char_buf
(
R_AS_INT
(
buffer_len
));
std
::
vector
<
char
>
inner_char_buf
(
R_AS_INT
(
buffer_len
));
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_GET_PTR
(
handle
),
R_AS_INT
(
num_iteration
),
R_AS_INT
(
buffer_len
),
&
out_len
,
inner_char_buf
.
data
()));
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_GET_PTR
(
handle
),
R_AS_INT
(
num_iteration
),
R_AS_INT
(
buffer_len
),
&
out_len
,
inner_char_buf
.
data
()));
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
if
(
out_len
<
R_AS_INT
(
buffer_len
))
{
if
(
out_len
<
R_AS_INT
(
buffer_len
))
{
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
}
else
{
}
else
{
R_INT_PTR
(
actual_len
)[
0
]
=
static_cast
<
int
>
(
out_len
);
if
(
out_len
<=
INT32_MAX
)
{
R_INT_PTR
(
actual_len
)[
0
]
=
static_cast
<
int
>
(
out_len
);
}
else
{
Log
::
Fatal
(
"Don't support large model in R package."
);
}
}
}
R_API_END
();
R_API_END
();
}
}
...
@@ -620,14 +623,17 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
...
@@ -620,14 +623,17 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE
out_str
,
LGBM_SE
out_str
,
LGBM_SE
call_state
)
{
LGBM_SE
call_state
)
{
R_API_BEGIN
();
R_API_BEGIN
();
int
out_len
=
0
;
int
64_t
out_len
=
0
;
std
::
vector
<
char
>
inner_char_buf
(
R_AS_INT
(
buffer_len
));
std
::
vector
<
char
>
inner_char_buf
(
R_AS_INT
(
buffer_len
));
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_GET_PTR
(
handle
),
R_AS_INT
(
num_iteration
),
R_AS_INT
(
buffer_len
),
&
out_len
,
inner_char_buf
.
data
()));
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_GET_PTR
(
handle
),
R_AS_INT
(
num_iteration
),
R_AS_INT
(
buffer_len
),
&
out_len
,
inner_char_buf
.
data
()));
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
if
(
out_len
<
R_AS_INT
(
buffer_len
))
{
if
(
out_len
<
R_AS_INT
(
buffer_len
))
{
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
EncodeChar
(
out_str
,
inner_char_buf
.
data
(),
buffer_len
,
actual_len
);
}
else
{
}
else
{
R_INT_PTR
(
actual_len
)[
0
]
=
static_cast
<
int
>
(
out_len
);
if
(
out_len
<=
INT32_MAX
)
{
R_INT_PTR
(
actual_len
)[
0
]
=
static_cast
<
int
>
(
out_len
);
}
else
{
Log
::
Fatal
(
"Don't support large model in R package."
);
}
}
}
R_API_END
();
R_API_END
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment