ssh_utils.py 1.83 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3

SparkSnail's avatar
SparkSnail committed
4
5
import os
from .common_utils import print_error
6
from .command_utils import install_package_command
7
8
9
10
11
12

def check_environment():
    '''check if paramiko is installed

    Import paramiko, installing it with install_package_command first if the
    import fails because the package is missing.

    Returns:
        The imported ``paramiko`` module.

    Raises:
        ImportError: if paramiko still cannot be imported after the install
            attempt.
    '''
    try:
        import paramiko
    except ImportError:
        # Only a missing module should trigger an install; a bare ``except``
        # would also react to KeyboardInterrupt/SystemExit and unrelated errors.
        install_package_command('paramiko')
        # The import system caches directory contents; a package installed
        # while the interpreter is running may not be found without this.
        import importlib
        importlib.invalidate_caches()
        import paramiko
    return paramiko
16

SparkSnail's avatar
SparkSnail committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def copy_remote_directory_to_local(sftp, remote_path, local_path):
    '''copy remote directory to local machine

    Recursively mirror the remote directory ``remote_path`` into
    ``local_path``, creating local directories (including empty ones) as
    needed. The copy is best-effort: errors are reported through print_error
    instead of being raised.

    Args:
        sftp: a connected SFTP client (for example the one returned by
            create_ssh_sftp_client).
        remote_path: path of the directory on the remote machine.
        local_path: local directory to copy into; created if missing.
    '''
    import posixpath
    import stat

    try:
        os.makedirs(local_path, exist_ok=True)
        for name in sftp.listdir(remote_path):
            # Remote SFTP paths use '/' regardless of the local OS, so they
            # must not be built with os.path (which uses '\\' on Windows).
            remote_full_path = posixpath.join(remote_path, name)
            local_full_path = os.path.join(local_path, name)
            # stat() follows symlinks, so a link to a directory is copied as a
            # directory; an unknown mode is treated as a regular file.
            mode = sftp.stat(remote_full_path).st_mode
            if mode is not None and stat.S_ISDIR(mode):
                copy_remote_directory_to_local(sftp, remote_full_path, local_full_path)
            else:
                sftp.get(remote_full_path, local_full_path)
    except Exception as exception:
        # Keep the best-effort contract (never raise), but do not fail silently.
        print_error('Copy remote directory error %s\n' % exception)

def create_ssh_sftp_client(host_ip, port, username, password):
    '''create ssh client

    Open an SSH transport to ``host_ip:port``, authenticate with
    ``username``/``password`` and return an SFTP client running over it.

    Args:
        host_ip: hostname or IP address of the remote machine.
        port: SSH port of the remote machine (int or numeric string).
        username: login user name.
        password: login password.

    Returns:
        A ``paramiko.SFTPClient``, or ``None`` if any step failed (the error
        is reported through print_error).
    '''
    conn = None
    try:
        paramiko = check_environment()
        # Transport takes the address as a single argument. Passing host and
        # port separately would bind ``port`` to ``default_window_size`` and
        # connect to the default SSH port instead of the requested one.
        conn = paramiko.Transport((host_ip, int(port)))
        conn.connect(username=username, password=password)
        return paramiko.SFTPClient.from_transport(conn)
    except Exception as exception:
        # Close a half-opened transport so its socket/thread is not leaked.
        if conn is not None:
            conn.close()
        print_error('Create ssh client error %s\n' % exception)
        return None
SparkSnail's avatar
SparkSnail committed
43
44
45
46
47
48
49
50
51
52
53
54
55

def remove_remote_directory(sftp, directory):
    '''remove a directory in remote machine

    Recursively delete ``directory`` and everything below it on the remote
    machine. IOErrors are reported through print_error instead of being
    raised.

    Args:
        sftp: a connected SFTP client.
        directory: remote path of the directory to delete.
    '''
    try:
        files = sftp.listdir(directory)
        for file in files:
            filepath = '/'.join([directory, file])
            try:
                # Try the entry as a file first; this relies on remove()
                # raising IOError for a directory, which is then recursed into.
                sftp.remove(filepath)
            except IOError:
                remove_remote_directory(sftp, filepath)
        # All entries have been processed; remove the directory itself.
        sftp.rmdir(directory)
    except IOError as err:
        print_error(err)