Unverified Commit 18a300aa authored by Rakshit P's avatar Rakshit P Committed by GitHub
Browse files

fix mypy error in engine.py (#4675)

parent b1b6db4b
...@@ -188,12 +188,11 @@ def train( ...@@ -188,12 +188,11 @@ def train(
if num_boost_round <= 0: if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.") raise ValueError("num_boost_round should be greater than zero.")
predictor: Optional[_InnerPredictor] = None
if isinstance(init_model, (str, Path)): if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
elif isinstance(init_model, Booster): elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(dict(init_model.params, **params)) predictor = init_model._to_predictor(dict(init_model.params, **params))
else:
predictor = None
init_iteration = predictor.num_total_iteration if predictor is not None else 0 init_iteration = predictor.num_total_iteration if predictor is not None else 0
# check dataset # check dataset
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment