ssh_context.py 10.8 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
#!/usr/bin/env python
# coding: utf-8

import os, sys, paramiko, json, uuid, tarfile, time, stat, shutil
from glob import glob

class SSHSession (object) :
    """Lazily connected paramiko SSH session described by a profile mapping.

    No connection is opened by ``__init__`` (so that the object can still be
    deep copied); the client is created on first use by ``ensure_alive`` or
    ``get_ssh_client``.
    """

    def __init__ (self, jdata) :
        """Store the connection parameters found in *jdata*.

        Parameters
        ----------
        jdata : mapping
            Remote profile. Required keys: ``'hostname'``, ``'username'`` and
            ``'work_path'``. Optional keys: ``'port'`` (defaults to 22) and
            ``'password'`` (defaults to ``None``).

        Raises
        ------
        KeyError
            If a required key is missing.
        """
        self.remote_profile = jdata
        self.remote_host = self.remote_profile['hostname']
        self.remote_port = self.remote_profile.get('port', 22)
        self.remote_uname = self.remote_profile['username']
        self.remote_password = self.remote_profile.get('password')
        self.remote_workpath = self.remote_profile['work_path']
        # postpone setup to runtime in order to be deep copied
        self.ssh = None

    def ensure_alive(self,
                     max_check = 10,
                     sleep_time = 10):
        """Reconnect until the connection passes ``_check_alive``.

        Parameters
        ----------
        max_check : int
            Number of failed liveness checks after which to give up.
        sleep_time : number
            Seconds to wait between two consecutive reconnection attempts.
            There is no wait before the first attempt, nor after an attempt
            that succeeded.

        Raises
        ------
        RuntimeError
            If the connection is still not alive after *max_check* failed
            checks. Exceptions raised while connecting are not caught here.
        """
        failures = 0
        while not self._check_alive():
            failures += 1
            if failures >= max_check:
                raise RuntimeError('cannot connect ssh after %d failures at interval %d s' %
                                   (max_check, sleep_time))
            if failures > 1:
                # the previous reconnect did not help: back off before retrying
                time.sleep(sleep_time)
            self._setup_ssh(self.remote_host,
                            self.remote_port,
                            username=self.remote_uname,
                            password=self.remote_password)

    def _check_alive(self):
        """Return True when a client exists and its transport still responds."""
        if self.ssh is None:
            return False
        transport = self.ssh.get_transport()
        # a client that never connected, or that was closed, has no transport
        if transport is None or not transport.is_active():
            return False
        try :
            transport.send_ignore()
        except EOFError:
            return False
        return True

    def _setup_ssh(self,
                   hostname,
                   port,
                   username = None,
                   password = None):
        """Open a fresh connection and store it in ``self.ssh``.

        Any previously stored client is closed first. When connecting fails
        ``self.ssh`` is left as ``None`` so that a later call to
        ``get_ssh_client`` tries again instead of returning a dead client.

        Raises
        ------
        RuntimeError
            If the transport is not active right after connecting.
        """
        stale, self.ssh = self.ssh, None
        if stale is not None:
            stale.close()
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.WarningPolicy)
        try:
            client.connect(hostname, port=port, username=username, password=password)
            transport = client.get_transport()
            if transport is None or not transport.is_active():
                raise RuntimeError('ssh transport to %s is not active after connecting' % hostname)
            transport.set_keepalive(60)
        except Exception:
            client.close()
            raise
        self.ssh = client

    def get_ssh_client(self) :
        """Return the SSH client, connecting first if none exists yet."""
        if self.ssh is None:
            self._setup_ssh(self.remote_host,
                            self.remote_port,
                            username=self.remote_uname,
                            password=self.remote_password)
        return self.ssh

    def get_session_root(self) :
        """Return the ``'work_path'`` given in the profile."""
        return self.remote_workpath

    def close(self):
        """Close the connection, if any, and forget the client.

        Dropping the reference lets ``get_ssh_client`` and ``ensure_alive``
        open a new connection if the session is used again.
        """
        if self.ssh is not None:
            self.ssh.close()
            self.ssh = None


class SSHContext (object):
    """Run commands and exchange files for one job directory on a remote host.

    Each context works in the remote directory
    ``<session root>/<job_uuid>``. Blocking commands are executed after a
    ``cd`` into that directory, and files travel between it and
    ``local_root`` as gzipped tarballs sent over SFTP.
    """

    def __init__ (self,
                  local_root,
                  ssh_session,
                  job_uuid=None,
    ) :
        """Bind the context to a local directory and a remote job directory.

        Parameters
        ----------
        local_root : str
            Local directory; stored as an absolute path.
        ssh_session : SSHSession
            Session providing the connection and the remote root path.
        job_uuid : str, optional
            Name of the remote job directory. A random UUID4 string is used
            when omitted or empty.

        Raises
        ------
        TypeError
            If *local_root* is not a ``str``.
        """
        if not isinstance(local_root, str):
            raise TypeError('local_root must be a str, got %s' % type(local_root).__name__)
        self.local_root = os.path.abspath(local_root)
        self.job_uuid = job_uuid if job_uuid else str(uuid.uuid4())
        self.remote_root = os.path.join(ssh_session.get_session_root(), self.job_uuid)
        self.ssh_session = ssh_session
        self.ssh_session.ensure_alive()
        sftp = None
        try:
            sftp = self.ssh.open_sftp()
            sftp.mkdir(self.remote_root)
        except Exception:
            # Best effort, as before: a failure here (for instance when the
            # directory is left over from an earlier context using the same
            # job_uuid) must not abort construction.
            pass
        finally:
            if sftp is not None:
                sftp.close()

    @property
    def ssh(self):
        """The SSH client of the underlying session (connected on demand)."""
        return self.ssh_session.get_ssh_client()

    def close(self):
        """Close the underlying SSH session."""
        self.ssh_session.close()

    def get_job_root(self) :
        """Return the remote job directory path."""
        return self.remote_root

    def upload(self,
               job_dirs,
               local_up_files,
               dereference = True) :
        """Send ``<job_dir>/<file>`` for every combination to the remote side.

        Paths are relative to ``local_root`` and keep the same relative
        location below the remote job directory. When there is nothing to
        upload, each job directory is created remotely with ``mkdir -p``.

        Parameters
        ----------
        job_dirs : iterable of str
        local_up_files : iterable of str
        dereference : bool
            Passed to ``tarfile.open``; when true, symbolic links are
            archived as the files they point to.
        """
        self.ssh_session.ensure_alive()
        file_list = []
        for ii in job_dirs :
            for jj in local_up_files :
                file_list.append(os.path.join(ii,jj))
            # still empty after a directory was processed: nothing is being
            # uploaded, so at least create the directory remotely
            if not file_list:
                self.block_checkcall('mkdir -p %s' % ii)
        self._put_files(file_list, dereference = dereference)

    def download(self,
                 job_dirs,
                 remote_down_files,
                 check_exists = False,
                 mark_failure = True,
                 back_error=False) :
        """Fetch ``<job_dir>/<file>`` for every combination into ``local_root``.

        Parameters
        ----------
        job_dirs : iterable of str
        remote_down_files : iterable of str
        check_exists : bool
            When true, only files that exist remotely are requested.
        mark_failure : bool
            Only used with *check_exists*: for each missing file an empty
            local file ``<job_dir>/tag_failure_download_<file>`` is written.
        back_error : bool
            When true, the pattern ``<job_dir>/err*`` is requested as well.
        """
        self.ssh_session.ensure_alive()
        file_list = []
        for ii in job_dirs :
            for jj in remote_down_files:
                file_name = os.path.join(ii,jj)
                if not check_exists or self.check_file_exists(file_name):
                    file_list.append(file_name)
                elif mark_failure :
                    tag = os.path.join(self.local_root, ii, 'tag_failure_download_%s' % jj)
                    with open(tag, 'w'):
                        pass
            if back_error:
                file_list.append(os.path.join(ii,'err*'))
        if file_list:
            self._get_files(file_list)

    def block_checkcall(self,
                        cmd,
                        retry=3,
                        retry_interval=60) :
        """Run *cmd* in the remote job directory and require exit status 0.

        Parameters
        ----------
        cmd : str
            Shell command; it is appended verbatim after ``cd <remote_root> ;``.
        retry : int
            Number of additional attempts after a non-zero exit status.
        retry_interval : number
            Seconds to sleep before each retry.

        Returns
        -------
        tuple
            ``(stdin, stdout, stderr)`` of the successful invocation.

        Raises
        ------
        RuntimeError
            If the command still fails once the retries are used up.
        """
        attempts_left = retry
        while True:
            self.ssh_session.ensure_alive()
            stdin, stdout, stderr = self.ssh.exec_command(('cd %s ;' % self.remote_root) + cmd)
            exit_status = stdout.channel.recv_exit_status()
            if exit_status == 0:
                return stdin, stdout, stderr
            message = stderr.read().decode('utf-8')
            if attempts_left <= 0:
                raise RuntimeError("Get error code %d in calling %s through ssh with job: %s . message: %s" %
                                   (exit_status, cmd, self.job_uuid, message))
            print("# Get error code %d in calling %s through ssh with job: %s . message: %s" %
                  (exit_status, cmd, self.job_uuid, message))
            print("# Sleep %s s and retry the command..." % retry_interval)
            time.sleep(retry_interval)
            attempts_left -= 1

    def block_call(self,
                   cmd) :
        """Run *cmd* in the remote job directory and wait for it to finish.

        Returns ``(exit_status, stdin, stdout, stderr)``; a non-zero exit
        status is reported, not raised.
        """
        self.ssh_session.ensure_alive()
        stdin, stdout, stderr = self.ssh.exec_command(('cd %s ;' % self.remote_root) + cmd)
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdin, stdout, stderr

    def clean(self) :
        """Recursively delete the remote job directory."""
        self.ssh_session.ensure_alive()
        sftp = self.ssh.open_sftp()
        try:
            self._rmtree(sftp, self.remote_root)
        finally:
            sftp.close()

    def write_file(self, fname, write_str):
        """Write *write_str* to *fname* below the remote job directory."""
        self.ssh_session.ensure_alive()
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, fname), 'w') as fp :
                fp.write(write_str)
        finally:
            sftp.close()

    def read_file(self, fname):
        """Return the content of remote *fname*, decoded as UTF-8."""
        self.ssh_session.ensure_alive()
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, fname), 'r') as fp:
                return fp.read().decode('utf-8')
        finally:
            sftp.close()

    def check_file_exists(self, fname):
        """Return True if *fname* can be stat'ed below the remote job directory."""
        self.ssh_session.ensure_alive()
        sftp = self.ssh.open_sftp()
        try:
            sftp.stat(os.path.join(self.remote_root, fname))
        except IOError:
            return False
        finally:
            sftp.close()
        return True

    def call(self, cmd):
        """Start *cmd* remotely without waiting for it.

        Unlike ``block_call`` the command is passed as given, without a
        leading ``cd`` into the remote job directory.

        Returns a dict with the keys ``'stdin'``, ``'stdout'`` and
        ``'stderr'``, to be handed to ``check_finish`` / ``get_return``.
        """
        stdin, stdout, stderr = self.ssh.exec_command(cmd)
        return {'stdin':stdin, 'stdout':stdout, 'stderr':stderr}

    def check_finish(self, cmd_pipes):
        """Return True once the command started by ``call`` has an exit status."""
        return cmd_pipes['stdout'].channel.exit_status_ready()

    def get_return(self, cmd_pipes):
        """Return ``(retcode, stdout, stderr)`` of a command started by ``call``.

        While the command is still running ``(None, None, None)`` is returned.
        """
        if not self.check_finish(cmd_pipes):
            return None, None, None
        retcode = cmd_pipes['stdout'].channel.recv_exit_status()
        return retcode, cmd_pipes['stdout'], cmd_pipes['stderr']

    def kill(self, cmd_pipes) :
        """Not supported; always raises ``RuntimeError``."""
        raise RuntimeError('dose not work! we do not know how to kill proc through paramiko.SSHClient')

    def _rmtree(self, sftp, remotepath, level=0, verbose = False):
        """Remove *remotepath* and everything below it through *sftp*.

        *level* is the recursion depth and only indents the messages printed
        when *verbose* is true.
        """
        indent = '    ' * level
        for entry in sftp.listdir_attr(remotepath):
            rpath = os.path.join(remotepath, entry.filename)
            if stat.S_ISDIR(entry.st_mode):
                # hand verbose down so that nested entries are reported too
                self._rmtree(sftp, rpath, level=(level + 1), verbose=verbose)
            else:
                if verbose:
                    print('# removing %s%s' % (indent, rpath))
                sftp.remove(rpath)
        if verbose:
            print('# removing %s%s' % (indent, remotepath))
        sftp.rmdir(remotepath)

    def _put_files(self,
                   files,
                   dereference = True) :
        """Pack *files* (relative to ``local_root``) and unpack them remotely.

        A temporary ``<job_uuid>.tgz`` is created in ``local_root``, uploaded
        into the remote job directory, extracted there with ``tar xf`` and
        removed on both sides. The local tarball and the SFTP channel are
        released even when a step fails. Nothing is done for an empty list.

        Raises
        ------
        FileNotFoundError
            If the upload reports a missing file; the original error is
            chained as the cause.
        """
        if not files:
            return
        of = self.job_uuid + '.tgz'
        from_f = os.path.join(self.local_root, of)
        to_f = os.path.join(self.remote_root, of)
        if os.path.isfile(from_f) :
            os.remove(from_f)
        try:
            # local tar; arcname keeps the member names relative to
            # local_root without changing the working directory
            with tarfile.open(from_f, "w:gz", dereference = dereference) as tar:
                for ii in files :
                    tar.add(os.path.join(self.local_root, ii), arcname=ii)
            # trans
            sftp = self.ssh.open_sftp()
            try:
                try:
                    sftp.put(from_f, to_f)
                except FileNotFoundError as err:
                    raise FileNotFoundError("from %s to %s Error!"%(from_f,to_f)) from err
                # remote extract
                self.block_checkcall('tar xf %s' % of)
                sftp.remove(to_f)
            finally:
                sftp.close()
        finally:
            # clean up
            if os.path.isfile(from_f):
                os.remove(from_f)

    def _get_files(self,
                   files) :
        """Pack *files* remotely and unpack them below ``local_root``.

        The entries are passed unquoted to a remote ``tar czf`` run in the
        job directory. The resulting ``<job_uuid>.tgz`` is downloaded,
        extracted into ``local_root`` and removed on both sides. The local
        tarball and the SFTP channel are released even when a step fails.
        Nothing is done for an empty list.
        """
        if not files:
            return
        of = self.job_uuid + '.tgz'
        # remote tar
        self.block_checkcall('tar czf %s %s' % (of, ' '.join(files)))
        # trans
        from_f = os.path.join(self.remote_root, of)
        to_f = os.path.join(self.local_root, of)
        if os.path.isfile(to_f) :
            os.remove(to_f)
        sftp = self.ssh.open_sftp()
        try:
            sftp.get(from_f, to_f)
            # extract
            # NOTE(review): member names come from the remote host and are
            # not validated before extraction.
            with tarfile.open(to_f, "r:gz") as tar:
                tar.extractall(path=self.local_root)
            sftp.remove(from_f)
        finally:
            # cleanup
            sftp.close()
            if os.path.isfile(to_f):
                os.remove(to_f)