Unverified commit 113da3af, authored by Frank Fineis and committed by GitHub
Browse files

[dask] [python] Store co-local data parts as dicts instead of lists (#3853)

* store data parts in dict instead of list

* simplify weight/group parts dict assignment
parent 02e4b791
...@@ -136,25 +136,24 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re ...@@ -136,25 +136,24 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
is_ranker = issubclass(model_factory, LGBMRanker) is_ranker = issubclass(model_factory, LGBMRanker)
# Concatenate many parts into one # Concatenate many parts into one
parts = tuple(zip(*list_of_parts)) data = _concat([x['data'] for x in list_of_parts])
data = _concat(parts[0]) label = _concat([x['label'] for x in list_of_parts])
label = _concat(parts[1])
if 'weight' in list_of_parts[0]:
weight = _concat([x['weight'] for x in list_of_parts])
else:
weight = None
if 'group' in list_of_parts[0]:
group = _concat([x['group'] for x in list_of_parts])
else:
group = None
try: try:
model = model_factory(**params) model = model_factory(**params)
if is_ranker: if is_ranker:
group = _concat(parts[-1])
if len(parts) == 4:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs) model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
else: else:
if len(parts) == 3:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, **kwargs) model.fit(data, y=label, sample_weight=weight, **kwargs)
finally: finally:
...@@ -197,29 +196,20 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -197,29 +196,20 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
""" """
params = deepcopy(params) params = deepcopy(params)
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True) data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False) label_parts = _split_to_parts(data=label, is_matrix=False)
parts = [{'data': x, 'label': y} for (x, y) in zip(data_parts, label_parts)]
if sample_weight is not None: if sample_weight is not None:
weight_parts = _split_to_parts(data=sample_weight, is_matrix=False) weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
else: for i in range(len(parts)):
weight_parts = None parts[i]['weight'] = weight_parts[i]
if group is not None: if group is not None:
group_parts = _split_to_parts(data=group, is_matrix=False) group_parts = _split_to_parts(data=group, is_matrix=False)
else: for i in range(len(parts)):
group_parts = None parts[i]['group'] = group_parts[i]
# choose between four options of (sample_weight, group) being (un)specified
if weight_parts is None and group_parts is None:
parts = zip(data_parts, label_parts)
elif weight_parts is not None and group_parts is None:
parts = zip(data_parts, label_parts, weight_parts)
elif weight_parts is None and group_parts is not None:
parts = zip(data_parts, label_parts, group_parts)
else:
parts = zip(data_parts, label_parts, weight_parts, group_parts)
# Start computation in the background # Start computation in the background
parts = list(map(delayed, parts)) parts = list(map(delayed, parts))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment