Unverified Commit cb772860 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Bug - Fix issue where the root MPI rank may not be the first host in the hostfile (#270)

**Description**
Launch MPI on the first host of the sorted host list read from the hostfile.
parent bcf6ea37
......@@ -26,6 +26,7 @@ def __init__(self, config):
'host_pattern': 'localhost',
'cmdline': '--forks 128',
}
self._head_host = None
if config:
inventory_file = getattr(config, 'host_file', None)
inventory_list = getattr(config, 'host_list', None)
......@@ -34,9 +35,10 @@ def __init__(self, config):
if inventory_file or inventory_list:
self._config['host_pattern'] = 'all'
inventory = InventoryManager(loader=DataLoader(), sources=inventory_file or f'{inventory_list},')
host_list = inventory.get_groups_dict()['all']
host_list = inventory.get_hosts(pattern='all', order='sorted')
if len(host_list) > 0:
self._config['cmdline'] = '--forks {}'.format(len(host_list))
self._head_host = host_list[0].get_name()
if inventory_list in ['localhost', '127.0.0.1']:
self._config['cmdline'] += ' --connection local'
self._config['cmdline'] += ' --inventory {}'.format(inventory_file or f'{inventory_list},')
......@@ -87,7 +89,10 @@ def update_mpi_config(self, ansible_config):
Returns:
dict: Updated Ansible config dict.
"""
ansible_config['host_pattern'] += '[0]'
if not self._head_host:
ansible_config['host_pattern'] += '[0]'
else:
ansible_config['host_pattern'] = self._head_host
return ansible_config
def get_shell_config(self, cmd):
......
......@@ -38,10 +38,12 @@ def setUp(self):
'host_password': 'pass',
})
)
_, self.test_mpi_host_file = tempfile.mkstemp()
def tearDown(self):
    """Remove the files referenced by the fixture once a test has finished.

    Unlinks the path stored in ``host_file`` first, then the one stored in
    ``test_mpi_host_file``.
    """
    for fixture_path in (self.host_file, self.test_mpi_host_file):
        Path(fixture_path).unlink()
def test_init_config(self):
"""Test initial config of client."""
......@@ -61,6 +63,63 @@ def test_update_mpi_config(self):
self.assertDictEqual(
self.ansible_client.update_mpi_config(self.ansible_client._config), {
**self.ansible_client._config,
'host_pattern': '10.0.0.10',
}
)
def test_update_mpi_config_for_different_inventory(self):
    """Verify update_mpi_config for several inventory file contents.

    Every scenario overwrites the shared inventory file, builds a new client
    from it, and compares the updated config with the expected host pattern.
    """
    # Pairs of (inventory file content, host pattern expected after the update).
    scenarios = (
        # Hosts listed out of order.
        (
            'all:\n hosts:\n 10.0.0.12:\n 10.0.0.11:\n 10.0.0.10:\n 10.0.0.13:\n 10.0.0.14:\n',
            '10.0.0.10',
        ),
        # Localhost as the only host.
        ('all:\n hosts:\n localhost:\n', 'localhost'),
        # Inventory that lists no host.
        ('all:\n hosts:\n', 'all[0]'),
    )
    for inventory_text, expected_pattern in scenarios:
        with open(self.test_mpi_host_file, 'w') as inventory_fd:
            inventory_fd.write(inventory_text)
        client_options = OmegaConf.create(
            {
                'host_file': self.test_mpi_host_file,
                'host_username': 'user',
                'host_password': 'pass',
            }
        )
        client = AnsibleClient(client_options)
        # The expected dict is built only after the call, so it is derived from
        # the client's config as it stands once update_mpi_config has returned.
        updated_config = client.update_mpi_config(client._config)
        expected_config = {
            **client._config,
            'host_pattern': expected_pattern,
        }
        self.assertDictEqual(updated_config, expected_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment