ansible.py 4.54 KB
Newer Older
1
2
3
4
5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Ansible Client."""

6
import tempfile
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from pathlib import Path

import ansible_runner
from ansible.parsing.dataloader import DataLoader
from ansible.inventory.manager import InventoryManager

from superbench.common.utils import logger


class AnsibleClient():
    """Ansible Client class.

    Builds a base ansible-runner config (host pattern, command line arguments and,
    optionally, passwords) from a config object, and derives per-invocation configs
    for shell commands and playbooks from that base.
    """
    def __init__(self, config):
        """Initialize.

        Args:
            config (DictConfig): Ansible config object. The optional attributes read are
                host_file, host_list, host_username, host_password and private_key.
                A falsy config keeps the localhost defaults.
        """
        self._playbook_path = Path(__file__).parent / 'playbooks'
        self._config = {
            'host_pattern': 'localhost',
            'cmdline': '--forks 128',
        }
        # Name of the first host in the sorted inventory, None when no inventory is given.
        self._head_host = None
        if config:
            inventory_file = getattr(config, 'host_file', None)
            inventory_list = getattr(config, 'host_list', None)
            if inventory_list:
                inventory_list = inventory_list.strip(',')
            if inventory_file or inventory_list:
                # The host file takes precedence over the host list. The trailing comma marks
                # the string as an inline host list instead of an inventory file path.
                inventory_source = inventory_file or f'{inventory_list},'
                self._config['host_pattern'] = 'all'
                inventory = InventoryManager(loader=DataLoader(), sources=inventory_source)
                host_list = inventory.get_hosts(pattern='all', order='sorted')
                if host_list:
                    # One fork per host, and remember the head host for update_mpi_config.
                    self._config['cmdline'] = '--forks {}'.format(len(host_list))
                    self._head_host = host_list[0].get_name()
                if inventory_list in ['localhost', '127.0.0.1']:
                    self._config['cmdline'] += ' --connection local'
                self._config['cmdline'] += ' --inventory {}'.format(inventory_source)
            username = getattr(config, 'host_username', None)
            if username:
                self._config['cmdline'] += ' --user {}'.format(username)
            password = getattr(config, 'host_password', None)
            if password:
                self._config['passwords'] = {
                    'password': password,
                    'passphrase': password,
                }
            key_file = getattr(config, 'private_key', None)
            if key_file:
                self._config['cmdline'] += ' --private-key {}'.format(key_file)
            elif password:
                # Without a private key, make Ansible prompt so the stored passwords get used.
                self._config['cmdline'] += ' --ask-pass --ask-become-pass'
        # Log a masked copy, the raw config may hold the SSH password.
        logger.info(self._loggable_config())

    def _loggable_config(self):
        """Get a copy of the base config that is safe to log.

        Returns:
            dict: Shallow copy of the base config with every value under 'passwords' masked.
        """
        safe_config = dict(self._config)
        if 'passwords' in safe_config:
            safe_config['passwords'] = {prompt: '***' for prompt in safe_config['passwords']}
        return safe_config

    def run(self, ansible_config, sudo=False):    # pragma: no cover
        """Run Ansible runner.

        Args:
            ansible_config (dict): Ansible config dict. It is not modified.
            sudo (bool): Run as sudo or not. Defaults to False.

        Returns:
            int: Ansible return code.
        """
        if sudo:
            logger.info('Run as sudo ...')
            # Extend a copy, so a config dict reused by the caller does not
            # accumulate '--become' across runs.
            ansible_config = {
                **ansible_config,
                'cmdline': ansible_config['cmdline'] + ' --become',
            }
        with tempfile.TemporaryDirectory(prefix='ansible') as tmpdir:
            r = ansible_runner.run(private_data_dir=tmpdir, **ansible_config)
            logger.debug(r.stats)
        if r.rc == 0:
            logger.info('Run succeed, return code {}.'.format(r.rc))
        else:
            logger.warning('Run failed, return code {}.'.format(r.rc))
        return r.rc

    def update_mpi_config(self, ansible_config):
        """Update ansible config for mpi, run on the first host of inventory group.

        The given dict is modified in place and returned. When the head host is known,
        'host_pattern' is replaced by its name, otherwise '[0]' is appended to the
        existing pattern.

        Args:
            ansible_config (dict): Ansible config dict.

        Returns:
            dict: Updated Ansible config dict.
        """
        if not self._head_host:
            ansible_config['host_pattern'] += '[0]'
        else:
            ansible_config['host_pattern'] = self._head_host
        return ansible_config

    def get_shell_config(self, cmd):
        """Get ansible config for shell module.

        Args:
            cmd (str): Shell command for config.

        Returns:
            dict: Ansible config dict, the base config plus 'module' and 'module_args'.
        """
        logger.info('Run {} on remote ...'.format(cmd))
        ansible_config = {
            **self._config,
            'module': 'shell',
            'module_args': cmd,
        }
        return ansible_config

    def get_playbook_config(self, playbook, extravars=None):
        """Get ansible config for playbook.

        Args:
            playbook (str): Playbook file name, resolved against the 'playbooks' directory
                next to this module.
            extravars (dict): Extra variables in playbook. Defaults to None.

        Returns:
            dict: Ansible config dict, the base config plus 'extravars' and 'playbook'.
        """
        logger.info('Run playbook {} ...'.format(playbook))
        ansible_config = {
            **self._config,
            'extravars': extravars,
            'playbook': str(self._playbook_path / playbook),
        }
        return ansible_config