Commit 216eaff7 (unverified), authored by James Lamb, committed by GitHub
Browse files

[python-package] add type annotations on Booster.trees_to_dataframe() inner functions (#5811)

parent 5f79626f
......@@ -3289,13 +3289,21 @@ class Booster:
if self.num_trees() == 0:
raise LightGBMError('There are no trees in this Booster and thus nothing to parse')
def _is_split_node(tree):
def _is_split_node(tree: Dict[str, Any]) -> bool:
return 'split_index' in tree.keys()
def create_node_record(tree, node_depth=1, tree_index=None,
feature_names=None, parent_node=None):
def _get_node_index(tree, tree_index):
def create_node_record(
tree: Dict[str, Any],
node_depth: int = 1,
tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None
) -> Dict[str, Any]:
def _get_node_index(
tree: Dict[str, Any],
tree_index: Optional[int]
) -> str:
tree_num = f'{tree_index}-' if tree_index is not None else ''
is_split = _is_split_node(tree)
node_type = 'S' if is_split else 'L'
......@@ -3303,7 +3311,10 @@ class Booster:
node_num = tree.get('split_index' if is_split else 'leaf_index', 0)
return f"{tree_num}{node_type}{node_num}"
def _get_split_feature(tree, feature_names):
def _get_split_feature(
tree: Dict[str, Any],
feature_names: Optional[List[str]]
) -> Optional[str]:
if _is_split_node(tree):
if feature_names is not None:
feature_name = feature_names[tree['split_feature']]
......@@ -3313,11 +3324,11 @@ class Booster:
feature_name = None
return feature_name
def _is_single_node_tree(tree):
def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
return set(tree.keys()) == {'leaf_value'}
# Create the node record, and populate universal data members
node = OrderedDict()
node: Dict[str, Union[int, str, None]] = OrderedDict()
node['tree_index'] = tree_index
node['node_depth'] = node_depth
node['node_index'] = _get_node_index(tree, tree_index)
......@@ -3354,10 +3365,15 @@ class Booster:
return node
def tree_dict_to_node_list(tree, node_depth=1, tree_index=None,
feature_names=None, parent_node=None):
def tree_dict_to_node_list(
tree: Dict[str, Any],
node_depth: int = 1,
tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None
) -> List[Dict[str, Any]]:
node = create_node_record(tree,
node = create_node_record(tree=tree,
node_depth=node_depth,
tree_index=tree_index,
feature_names=feature_names,
......@@ -3370,11 +3386,12 @@ class Booster:
children = ['left_child', 'right_child']
for child in children:
subtree_list = tree_dict_to_node_list(
tree[child],
tree=tree[child],
node_depth=node_depth + 1,
tree_index=tree_index,
feature_names=feature_names,
parent_node=node['node_index'])
parent_node=node['node_index']
)
# In tree format, "subtree_list" is a list of node records (dicts),
# and we add node to the list.
res.extend(subtree_list)
......@@ -3384,7 +3401,7 @@ class Booster:
feature_names = model_dict['feature_names']
model_list = []
for tree in model_dict['tree_info']:
model_list.extend(tree_dict_to_node_list(tree['tree_structure'],
model_list.extend(tree_dict_to_node_list(tree=tree['tree_structure'],
tree_index=tree['tree_index'],
feature_names=feature_names))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment