Commit f5aa1c9b authored by j-mark-hou's avatar j-mark-hou Committed by Guolin Ke
Browse files

[python package] added get_ref_chain function to Dataset object, used it in train() to… (#745)

* added get_ref_chain function to Dataset object, used it in train() to compare if training and validation have common reference ancestors

* moved check of common ancestor reference from train() in engine.py to set_reference() in Dataset

* moved check of common ancestor reference from train() in engine.py to set_reference() in Dataset

* removed check for handle = None
parent 6c3f2448
......@@ -1004,7 +1004,8 @@ class Dataset(object):
self.set_categorical_feature(reference.categorical_feature)
self.set_feature_name(reference.feature_name)
self._set_predictor(reference._predictor)
if self.reference is reference:
# we're done if self and reference share a common upstream reference
if self.get_ref_chain().intersection(reference.get_ref_chain()):
return
if self.data is not None:
self.reference = reference
......@@ -1174,6 +1175,28 @@ class Dataset(object):
else:
raise LightGBMError("Cannot get num_feature before construct dataset")
def get_ref_chain(self, ref_limit=100):
    """
    Get the chain of Dataset objects reachable by following ``reference``.

    Starts with self, then goes to self.reference if it exists, then to
    self.reference.reference, etc. The walk stops as soon as ref_limit
    Datasets have been collected, the next reference is missing or is not
    a Dataset, or the next reference is already in the chain (a loop).

    Parameters
    ----------
    ref_limit : int, optional (default=100)
        Maximum number of Dataset objects to collect.

    Returns
    -------
    ref_chain : set of Dataset
        Chain of references of self (includes self when ref_limit >= 1).
    """
    head = self
    # Collect straight into a set: loop detection becomes a hash lookup
    # instead of a linear scan, and the result is already the return type.
    ref_chain = set()
    while len(ref_chain) < ref_limit and isinstance(head, Dataset):
        ref_chain.add(head)
        next_ref = head.reference
        # Check the type before the membership test so that a non-Dataset
        # (possibly unhashable) reference simply ends the chain.
        if not isinstance(next_ref, Dataset) or next_ref in ref_chain:
            break
        head = next_ref
    return ref_chain
class Booster(object):
""""Booster in LightGBM."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment