Unverified Commit 98a85a83 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[dask] Drop aliases of core network parameters (#3843)

* Update dask.py

* Update basic.py

* hotfix pop
parent b7ccdaf0
...@@ -298,6 +298,10 @@ class _ConfigAliases: ...@@ -298,6 +298,10 @@ class _ConfigAliases:
"local_listen_port": {"local_listen_port", "local_listen_port": {"local_listen_port",
"local_port", "local_port",
"port"}, "port"},
"machine_list_filename": {"machine_list_filename",
"machine_list_file",
"machine_list",
"mlist"},
"machines": {"machines", "machines": {"machines",
"workers", "workers",
"nodes"}, "nodes"},
...@@ -315,6 +319,8 @@ class _ConfigAliases: ...@@ -315,6 +319,8 @@ class _ConfigAliases:
"num_rounds", "num_rounds",
"num_boost_round", "num_boost_round",
"n_estimators"}, "n_estimators"},
"num_machines": {"num_machines",
"num_machine"},
"num_threads": {"num_threads", "num_threads": {"num_threads",
"num_thread", "num_thread",
"nthread", "nthread",
......
...@@ -230,7 +230,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -230,7 +230,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
return part # trigger error locally return part # trigger error locally
# Find locations of all parts and map them to particular Dask workers # Find locations of all parts and map them to particular Dask workers
key_to_part_dict = dict([(part.key, part) for part in parts]) key_to_part_dict = {part.key: part for part in parts}
who_has = client.who_has(parts) who_has = client.who_has(parts)
worker_map = defaultdict(list) worker_map = defaultdict(list)
for key, workers in who_has.items(): for key, workers in who_has.items():
...@@ -280,6 +280,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -280,6 +280,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
for num_thread_alias in _ConfigAliases.get('num_threads'): for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None) params.pop(num_thread_alias, None)
# machines is constructed manually, so remove it and all aliases of it from params
for machine_alias in _ConfigAliases.get('machines'):
params.pop(machine_alias, None)
# machines is constructed manually, so remove machine_list_filename and all aliases of it from params
for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
params.pop(machine_list_filename_alias, None)
# machines is constructed manually, so remove num_machines and all aliases of it from params
for num_machine_alias in _ConfigAliases.get('num_machines'):
params.pop(num_machine_alias, None)
# Tell each worker to train on the parts that it has locally # Tell each worker to train on the parts that it has locally
futures_classifiers = [ futures_classifiers = [
client.submit( client.submit(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment