ansible.py 4.73 KB
Newer Older
1
2
3
4
5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Ansible Client."""

6
import tempfile
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from pathlib import Path

import ansible_runner
from ansible.parsing.dataloader import DataLoader
from ansible.inventory.manager import InventoryManager

from superbench.common.utils import logger


class AnsibleClient():
    """Ansible Client class.

    Builds ansible-runner config dicts (ad-hoc shell module or playbook) for the
    hosts described by a SuperBench config, and runs them.
    """
    def __init__(self, config):
        """Initialize.

        Args:
            config (DictConfig): Ansible config object. Optional attributes read:
                host_file, host_list, host_username, host_password, private_key.
        """
        self._playbook_path = Path(__file__).parent / 'playbooks'
        # Defaults target the local machine when no inventory is configured.
        self._config = {
            'host_pattern': 'localhost',
            'cmdline': '--forks 128',
        }
        # Name of the first host (sorted order) in the inventory, used for MPI runs.
        self._head_host = None
        # Number of run() calls that finished with a non-zero return code.
        self.failure_count = 0

        if config:
            inventory_file = getattr(config, 'host_file', None)
            inventory_list = getattr(config, 'host_list', None)
            if inventory_list:
                inventory_list = inventory_list.strip(',')
            if inventory_file or inventory_list:
                # The trailing comma marks an inline host list rather than a file path
                # (Ansible's `-i host1,host2,` convention).
                inventory_source = inventory_file or f'{inventory_list},'
                self._config['host_pattern'] = 'all'
                inventory = InventoryManager(loader=DataLoader(), sources=inventory_source)
                host_list = inventory.get_hosts(pattern='all', order='sorted')
                if len(host_list) > 0:
                    # One fork per host so all hosts are handled in parallel.
                    self._config['cmdline'] = '--forks {}'.format(len(host_list))
                    self._head_host = host_list[0].get_name()
                if inventory_list in ['localhost', '127.0.0.1']:
                    self._config['cmdline'] += ' --connection local'
                self._config['cmdline'] += ' --inventory {}'.format(inventory_source)
            username = getattr(config, 'host_username', None)
            if username:
                self._config['cmdline'] += ' --user {}'.format(username)
            password = getattr(config, 'host_password', None)
            if password:
                self._config['passwords'] = {
                    'password': password,
                    'passphrase': password,
                }
            key_file = getattr(config, 'private_key', None)
            if key_file:
                self._config['cmdline'] += ' --private-key {}'.format(key_file)
            elif password:
                # Without a key file, authenticate (and become) with the password.
                self._config['cmdline'] += ' --ask-pass --ask-become-pass'
        # Log a redacted copy so credentials never end up in log files.
        logger.info(self._redact_config(self._config))

    @staticmethod
    def _redact_config(config):
        """Return a shallow copy of an Ansible config dict that is safe to log.

        Args:
            config (dict): Ansible config dict, possibly holding a 'passwords' mapping.

        Returns:
            dict: Copy of config where every value under 'passwords' is masked.
                The input dict is not modified.
        """
        redacted = dict(config)
        if redacted.get('passwords'):
            redacted['passwords'] = {key: '******' for key in redacted['passwords']}
        return redacted

    def run(self, ansible_config, cancel_callback=None, sudo=False):    # pragma: no cover
        """Run Ansible runner.

        Args:
            ansible_config (dict): Ansible config dict. It is not modified.
            cancel_callback (Callable): Ansible runner cancel callback.
            sudo (bool): Run as sudo or not. Defaults to False.

        Returns:
            int: Ansible return code.
        """
        if sudo:
            logger.info('Run as sudo ...')
            # Build a new dict instead of mutating the caller's, so reusing one config
            # across calls does not accumulate repeated '--become' flags.
            ansible_config = {**ansible_config, 'cmdline': ansible_config['cmdline'] + ' --become'}
        # ansible-runner requires a private data dir; a throwaway one is used per run.
        with tempfile.TemporaryDirectory(prefix='ansible') as tmpdir:
            r = ansible_runner.run(private_data_dir=tmpdir, cancel_callback=cancel_callback, **ansible_config)
            logger.debug(r.stats)
        if r.rc == 0:
            logger.info('Run succeed, return code {}.'.format(r.rc))
        else:
            self.failure_count += 1
            logger.warning('Run failed, return code {}.'.format(r.rc))
        return r.rc

    def update_mpi_config(self, ansible_config):
        """Update ansible config for mpi, run on the first host of inventory group.

        Args:
            ansible_config (dict): Ansible config dict, updated in place.

        Returns:
            dict: Updated Ansible config dict (the same object that was passed in).
        """
        if not self._head_host:
            # No known head host: let Ansible select the first host of the pattern.
            ansible_config['host_pattern'] += '[0]'
        else:
            ansible_config['host_pattern'] = self._head_host
        return ansible_config

    def get_shell_config(self, cmd):
        """Get ansible config for shell module.

        Args:
            cmd (str): Shell command for config.

        Returns:
            dict: Ansible config dict.
        """
        logger.info('Run {} on remote ...'.format(cmd))
        ansible_config = {
            **self._config,
            'module': 'shell',
            'module_args': cmd,
        }
        return ansible_config

    def get_playbook_config(self, playbook, extravars=None):
        """Get ansible config for playbook.

        Args:
            playbook (str): Playbook file name, resolved under the bundled playbooks dir.
            extravars (dict): Extra variables in playbook. Defaults to None.

        Returns:
            dict: Ansible config dict.
        """
        logger.info('Run playbook {} ...'.format(playbook))
        ansible_config = {
            **self._config,
            'extravars': extravars,
            'playbook': str(self._playbook_path / playbook),
        }
        return ansible_config