Commit 216eaff7 (unverified), authored by James Lamb, committed by GitHub
Browse files

[python-package] add type annotations on Booster.trees_to_dataframe() inner functions (#5811)

parent 5f79626f
...@@ -3289,13 +3289,21 @@ class Booster: ...@@ -3289,13 +3289,21 @@ class Booster:
if self.num_trees() == 0: if self.num_trees() == 0:
raise LightGBMError('There are no trees in this Booster and thus nothing to parse') raise LightGBMError('There are no trees in this Booster and thus nothing to parse')
def _is_split_node(tree): def _is_split_node(tree: Dict[str, Any]) -> bool:
return 'split_index' in tree.keys() return 'split_index' in tree.keys()
def create_node_record(tree, node_depth=1, tree_index=None, def create_node_record(
feature_names=None, parent_node=None): tree: Dict[str, Any],
node_depth: int = 1,
tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None
) -> Dict[str, Any]:
def _get_node_index(tree, tree_index): def _get_node_index(
tree: Dict[str, Any],
tree_index: Optional[int]
) -> str:
tree_num = f'{tree_index}-' if tree_index is not None else '' tree_num = f'{tree_index}-' if tree_index is not None else ''
is_split = _is_split_node(tree) is_split = _is_split_node(tree)
node_type = 'S' if is_split else 'L' node_type = 'S' if is_split else 'L'
...@@ -3303,7 +3311,10 @@ class Booster: ...@@ -3303,7 +3311,10 @@ class Booster:
node_num = tree.get('split_index' if is_split else 'leaf_index', 0) node_num = tree.get('split_index' if is_split else 'leaf_index', 0)
return f"{tree_num}{node_type}{node_num}" return f"{tree_num}{node_type}{node_num}"
def _get_split_feature(tree, feature_names): def _get_split_feature(
tree: Dict[str, Any],
feature_names: Optional[List[str]]
) -> Optional[str]:
if _is_split_node(tree): if _is_split_node(tree):
if feature_names is not None: if feature_names is not None:
feature_name = feature_names[tree['split_feature']] feature_name = feature_names[tree['split_feature']]
...@@ -3313,11 +3324,11 @@ class Booster: ...@@ -3313,11 +3324,11 @@ class Booster:
feature_name = None feature_name = None
return feature_name return feature_name
def _is_single_node_tree(tree): def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
return set(tree.keys()) == {'leaf_value'} return set(tree.keys()) == {'leaf_value'}
# Create the node record, and populate universal data members # Create the node record, and populate universal data members
node = OrderedDict() node: Dict[str, Union[int, str, None]] = OrderedDict()
node['tree_index'] = tree_index node['tree_index'] = tree_index
node['node_depth'] = node_depth node['node_depth'] = node_depth
node['node_index'] = _get_node_index(tree, tree_index) node['node_index'] = _get_node_index(tree, tree_index)
...@@ -3354,10 +3365,15 @@ class Booster: ...@@ -3354,10 +3365,15 @@ class Booster:
return node return node
def tree_dict_to_node_list(tree, node_depth=1, tree_index=None, def tree_dict_to_node_list(
feature_names=None, parent_node=None): tree: Dict[str, Any],
node_depth: int = 1,
tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None
) -> List[Dict[str, Any]]:
node = create_node_record(tree, node = create_node_record(tree=tree,
node_depth=node_depth, node_depth=node_depth,
tree_index=tree_index, tree_index=tree_index,
feature_names=feature_names, feature_names=feature_names,
...@@ -3370,11 +3386,12 @@ class Booster: ...@@ -3370,11 +3386,12 @@ class Booster:
children = ['left_child', 'right_child'] children = ['left_child', 'right_child']
for child in children: for child in children:
subtree_list = tree_dict_to_node_list( subtree_list = tree_dict_to_node_list(
tree[child], tree=tree[child],
node_depth=node_depth + 1, node_depth=node_depth + 1,
tree_index=tree_index, tree_index=tree_index,
feature_names=feature_names, feature_names=feature_names,
parent_node=node['node_index']) parent_node=node['node_index']
)
# In tree format, "subtree_list" is a list of node records (dicts), # In tree format, "subtree_list" is a list of node records (dicts),
# and we add node to the list. # and we add node to the list.
res.extend(subtree_list) res.extend(subtree_list)
...@@ -3384,7 +3401,7 @@ class Booster: ...@@ -3384,7 +3401,7 @@ class Booster:
feature_names = model_dict['feature_names'] feature_names = model_dict['feature_names']
model_list = [] model_list = []
for tree in model_dict['tree_info']: for tree in model_dict['tree_info']:
model_list.extend(tree_dict_to_node_list(tree['tree_structure'], model_list.extend(tree_dict_to_node_list(tree=tree['tree_structure'],
tree_index=tree['tree_index'], tree_index=tree['tree_index'],
feature_names=feature_names)) feature_names=feature_names))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment