Unverified commit 33d90f46, authored by James Lamb and committed by GitHub
Browse files

[ci] [python-package] resolve remaining mypy errors in dask.py (#5858)

parent 11e17f39
...@@ -576,13 +576,48 @@ def _train( ...@@ -576,13 +576,48 @@ def _train(
# pad eval sets when they come in different sizes. # pad eval sets when they come in different sizes.
n_largest_eval_parts = max(x[0].npartitions for x in eval_set) n_largest_eval_parts = max(x[0].npartitions for x in eval_set)
eval_sets = defaultdict(list) eval_sets: Dict[
int,
List[
Union[
_DatasetNames,
Tuple[
List[Optional[_DaskMatrixLike]],
List[Optional[_DaskVectorLike]]
]
]
]
] = defaultdict(list)
if eval_sample_weight: if eval_sample_weight:
eval_sample_weights = defaultdict(list) eval_sample_weights: Dict[
int,
List[
Union[
_DatasetNames,
List[Optional[_DaskVectorLike]]
]
]
] = defaultdict(list)
if eval_group: if eval_group:
eval_groups = defaultdict(list) eval_groups: Dict[
int,
List[
Union[
_DatasetNames,
List[Optional[_DaskVectorLike]]
]
]
] = defaultdict(list)
if eval_init_score: if eval_init_score:
eval_init_scores = defaultdict(list) eval_init_scores: Dict[
int,
List[
Union[
_DatasetNames,
List[Optional[_DaskMatrixLike]]
]
]
] = defaultdict(list)
for i, (X_eval, y_eval) in enumerate(eval_set): for i, (X_eval, y_eval) in enumerate(eval_set):
n_this_eval_parts = X_eval.npartitions n_this_eval_parts = X_eval.npartitions
...@@ -610,8 +645,8 @@ def _train( ...@@ -610,8 +645,8 @@ def _train(
eval_sets[parts_idx].append(([x_e], [y_e])) eval_sets[parts_idx].append(([x_e], [y_e]))
else: else:
# append additional chunks of this eval set to this part. # append additional chunks of this eval set to this part.
eval_sets[parts_idx][-1][0].append(x_e) eval_sets[parts_idx][-1][0].append(x_e) # type: ignore[index, union-attr]
eval_sets[parts_idx][-1][1].append(y_e) eval_sets[parts_idx][-1][1].append(y_e) # type: ignore[index, union-attr]
if eval_sample_weight: if eval_sample_weight:
if eval_sample_weight[i] is sample_weight: if eval_sample_weight[i] is sample_weight:
...@@ -631,7 +666,7 @@ def _train( ...@@ -631,7 +666,7 @@ def _train(
if j < n_parts: if j < n_parts:
eval_sample_weights[parts_idx].append([w_e]) eval_sample_weights[parts_idx].append([w_e])
else: else:
eval_sample_weights[parts_idx][-1].append(w_e) eval_sample_weights[parts_idx][-1].append(w_e) # type: ignore[union-attr]
if eval_init_score: if eval_init_score:
if eval_init_score[i] is init_score: if eval_init_score[i] is init_score:
...@@ -649,7 +684,7 @@ def _train( ...@@ -649,7 +684,7 @@ def _train(
if j < n_parts: if j < n_parts:
eval_init_scores[parts_idx].append([init_score_e]) eval_init_scores[parts_idx].append([init_score_e])
else: else:
eval_init_scores[parts_idx][-1].append(init_score_e) eval_init_scores[parts_idx][-1].append(init_score_e) # type: ignore[union-attr]
if eval_group: if eval_group:
if eval_group[i] is group: if eval_group[i] is group:
...@@ -667,7 +702,7 @@ def _train( ...@@ -667,7 +702,7 @@ def _train(
if j < n_parts: if j < n_parts:
eval_groups[parts_idx].append([g_e]) eval_groups[parts_idx].append([g_e])
else: else:
eval_groups[parts_idx][-1].append(g_e) eval_groups[parts_idx][-1].append(g_e) # type: ignore[union-attr]
# assign sub-eval_set components to worker parts. # assign sub-eval_set components to worker parts.
for parts_idx, e_set in eval_sets.items(): for parts_idx, e_set in eval_sets.items():
...@@ -686,7 +721,8 @@ def _train( ...@@ -686,7 +721,8 @@ def _train(
for part in parts: for part in parts:
if part.status == 'error': # type: ignore if part.status == 'error': # type: ignore
return part # trigger error locally # trigger error locally
return part # type: ignore[return-value]
# Find locations of all parts and map them to particular Dask workers # Find locations of all parts and map them to particular Dask workers
key_to_part_dict = {part.key: part for part in parts} # type: ignore key_to_part_dict = {part.key: part for part in parts} # type: ignore
...@@ -701,7 +737,7 @@ def _train( ...@@ -701,7 +737,7 @@ def _train(
for worker in worker_map: for worker in worker_map:
has_eval_set = False has_eval_set = False
for part in worker_map[worker]: for part in worker_map[worker]:
if 'eval_set' in part.result(): if 'eval_set' in part.result(): # type: ignore[attr-defined]
has_eval_set = True has_eval_set = True
break break
...@@ -1002,7 +1038,7 @@ def _predict( ...@@ -1002,7 +1038,7 @@ def _predict(
**kwargs, **kwargs,
) )
pred_row = predict_fn(data_row) pred_row = predict_fn(data_row)
chunks = (data.chunks[0],) chunks: Tuple[int, ...] = (data.chunks[0],)
map_blocks_kwargs = {} map_blocks_kwargs = {}
if len(pred_row.shape) > 1: if len(pred_row.shape) > 1:
chunks += (pred_row.shape[1],) chunks += (pred_row.shape[1],)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment