ansible.py 4.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Ansible Client."""

from pathlib import Path

import ansible_runner
from ansible.parsing.dataloader import DataLoader
from ansible.inventory.manager import InventoryManager

from superbench.common.utils import logger


class AnsibleClient():
    """Ansible Client class.

    Builds a base keyword-argument dict for ``ansible_runner.run`` from the
    SuperBench Ansible config and derives per-invocation configs (shell module
    or playbook) from it.
    """
    def __init__(self, config):
        """Initialize.

        Args:
            config (DictConfig): Ansible config object. The attributes read are
                ``host_file``, ``host_list``, ``host_username``, ``host_password``
                and ``private_key``; each one is optional. A falsy config keeps
                the defaults (``localhost`` host pattern, 128 forks).
        """
        self._playbook_path = Path(__file__).parent / 'playbooks'
        self._config = {
            'private_data_dir': None,
            'inventory': None,
            'host_pattern': 'localhost',
            'cmdline': '--forks 128',
        }
        if config:
            inventory_file = getattr(config, 'host_file', None)
            inventory_list = getattr(config, 'host_list', None)
            if inventory_list:
                inventory_list = inventory_list.strip(',')
            if inventory_file or inventory_list:
                # The host file takes precedence over the inline host list.
                self._config['inventory'] = inventory_file or inventory_list
                self._config['host_pattern'] = 'all'
                # The trailing comma marks the source as an inline host list rather than a file path.
                inventory = InventoryManager(loader=DataLoader(), sources=inventory_file or f'{inventory_list},')
                host_list = inventory.get_groups_dict()['all']
                if host_list:
                    # Use one fork per host in the inventory.
                    self._config['cmdline'] = '--forks {}'.format(len(host_list))
                if inventory_list in ['localhost', '127.0.0.1']:
                    self._config['cmdline'] += ' --connection local'
            username = getattr(config, 'host_username', None)
            if username:
                self._config['cmdline'] += ' --user {}'.format(username)
            password = getattr(config, 'host_password', None)
            if password:
                self._config['passwords'] = {
                    'password': password,
                    'passphrase': password,
                }
            key_file = getattr(config, 'private_key', None)
            if key_file:
                self._config['cmdline'] += ' --private-key {}'.format(key_file)
            elif password:
                self._config['cmdline'] += ' --ask-pass --ask-become-pass'
        # Log a masked copy so that credentials never end up in the log output.
        logger.info(self._loggable_config(self._config))

    @staticmethod
    def _loggable_config(config):
        """Return a shallow copy of config whose password values are masked.

        Args:
            config (dict): Ansible config dict, possibly holding a 'passwords' entry.

        Returns:
            dict: Copy of config that is safe to log; the input is left untouched.
        """
        safe_config = dict(config)
        passwords = safe_config.get('passwords')
        if passwords:
            safe_config['passwords'] = {key: '******' for key in passwords}
        return safe_config

    def run(self, ansible_config, sudo=False):    # pragma: no cover
        """Run Ansible runner.

        Args:
            ansible_config (dict): Ansible config dict, passed to ``ansible_runner.run`` as keyword arguments.
            sudo (bool): Run as sudo or not. Defaults to False.

        Returns:
            int: Ansible return code.
        """
        if sudo:
            logger.info('Run as sudo ...')
            # Work on a copy so that a config reused by the caller does not accumulate '--become'.
            ansible_config = {
                **ansible_config,
                'cmdline': '{} --become'.format(ansible_config.get('cmdline') or '').strip(),
            }
        r = ansible_runner.run(**ansible_config)
        if r.rc == 0:
            logger.info('Run succeed, return code {}.'.format(r.rc))
        else:
            logger.warning('Run failed, return code {}.'.format(r.rc))
        logger.info(r.stats)
        return r.rc

    def update_mpi_config(self, ansible_config):
        """Update ansible config for mpi, run on the first host of inventory group.

        The given dict is updated in place and also returned.

        Args:
            ansible_config (dict): Ansible config dict.

        Returns:
            dict: Updated Ansible config dict.
        """
        # The '[0]' subscript narrows the host pattern to the first matched host.
        ansible_config['host_pattern'] += '[0]'
        return ansible_config

    def get_shell_config(self, cmd):
        """Get ansible config for shell module.

        Args:
            cmd (str): Shell command for config.

        Returns:
            dict: Ansible config dict, a new dict built from the base config on every call.
        """
        logger.info('Run {} on remote ...'.format(cmd))
        ansible_config = {
            **self._config,
            'module': 'shell',
            'module_args': cmd,
        }
        return ansible_config

    def get_playbook_config(self, playbook, extravars=None):
        """Get ansible config for playbook.

        Args:
            playbook (str): Playbook file name, resolved against the 'playbooks' directory next to this module.
            extravars (dict): Extra variables in playbook. Defaults to None.

        Returns:
            dict: Ansible config dict, a new dict built from the base config on every call.
        """
        logger.info('Run playbook {} ...'.format(playbook))
        ansible_config = {
            **self._config,
            'extravars': extravars,
            'playbook': str(self._playbook_path / playbook),
        }
        return ansible_config