Unverified Commit 11d7608f authored by Xavier Dupré's avatar Xavier Dupré Committed by GitHub
Browse files

[python] add parameter object_hook to method dump_model (#4533)



* add parameter object_hook to function dump_model (python API)

* eol

* fix syntax

* lint

* better documentation

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarxavier dupré <xavier.dupre@gmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 4db10d86
...@@ -3342,7 +3342,7 @@ class Booster: ...@@ -3342,7 +3342,7 @@ class Booster:
ret += _dump_pandas_categorical(self.pandas_categorical) ret += _dump_pandas_categorical(self.pandas_categorical)
return ret return ret
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'): def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split', object_hook=None):
"""Dump Booster to JSON format. """Dump Booster to JSON format.
Parameters Parameters
...@@ -3357,6 +3357,15 @@ class Booster: ...@@ -3357,6 +3357,15 @@ class Booster:
What type of feature importance should be dumped. What type of feature importance should be dumped.
If "split", result contains numbers of times the feature is used in a model. If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature. If "gain", result contains total gains of splits which use the feature.
object_hook : callable or None, optional (default=None)
If not None, ``object_hook`` is a function called while parsing the json
string returned by the C API. It may be used to alter the json, to store
specific values while building the json structure. It avoids
walking through the structure again. It saves a significant amount
of time if the number of trees is huge.
Signature is ``def object_hook(node: dict) -> dict``.
None is equivalent to ``lambda node: node``.
See documentation of ``json.loads()`` for further details.
Returns Returns
------- -------
...@@ -3391,7 +3400,7 @@ class Booster: ...@@ -3391,7 +3400,7 @@ class Booster:
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
ret = json.loads(string_buffer.value.decode('utf-8')) ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook)
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical, ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
default=json_default_with_numpy)) default=json_default_with_numpy))
return ret return ret
......
...@@ -2846,3 +2846,23 @@ def test_dump_model(): ...@@ -2846,3 +2846,23 @@ def test_dump_model():
assert "leaf_const" in dumped_model_str assert "leaf_const" in dumped_model_str
assert "leaf_value" in dumped_model_str assert "leaf_value" in dumped_model_str
assert "leaf_count" in dumped_model_str assert "leaf_count" in dumped_model_str
def test_dump_model_hook():
def hook(obj):
if 'leaf_value' in obj:
obj['LV'] = obj['leaf_value']
del obj['leaf_value']
return obj
X, y = load_breast_cancer(return_X_y=True)
train_data = lgb.Dataset(X, label=y)
params = {
"objective": "binary",
"verbose": -1
}
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
assert "leaf_value" not in dumped_model_str
assert "LV" in dumped_model_str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment