import asyncio

import pickle
import os
import pathlib
import shutil
import re


import modes.experiments as exp

class Run(object):
    """A single execution of an experiment together with its environment."""

    def __init__(self, experiment, index, env, outpath, prereq=None):
        # experiment: the experiment definition to execute
        # index: run number, used to build a unique name
        # env: execution environment (provides workdir/cpdir paths)
        # outpath: file path where the serialized run output is written
        # prereq: optional Run that must complete before this one starts
        self.experiment = experiment
        self.index = index
        self.env = env
        self.outpath = outpath
        self.prereq = prereq
        self.output = None  # filled in by the runtime after execution

    def name(self):
        """Unique name combining the experiment name and the run index."""
        return f'{self.experiment.name}.{self.index}'

    def prep_dirs(self):
        """Recreate clean working and checkpoint directories for the run."""
        workdir = pathlib.Path(self.env.workdir)
        cpdir = pathlib.Path(self.env.cpdir)

        # wipe stale state from any previous run
        shutil.rmtree(workdir, ignore_errors=True)
        if self.env.create_cp:
            shutil.rmtree(cpdir, ignore_errors=True)

        workdir.mkdir(parents=True, exist_ok=True)
        cpdir.mkdir(parents=True, exist_ok=True)

class Runtime(object):
    """Abstract interface for run schedulers: queue runs, then execute them."""

    def add_run(self, run):
        """Register *run* for later execution; base implementation is a no-op."""
        pass

    def start(self):
        """Execute all registered runs; base implementation is a no-op."""
        pass

class LocalSimpleRuntime(Runtime):
    """Executes queued runs sequentially on the local machine."""

    def __init__(self, verbose=False):
        self.runnable = []   # runs waiting to execute
        self.complete = []   # runs that have finished
        self.verbose = verbose

    def add_run(self, run):
        """Queue a run for sequential execution."""
        self.runnable.append(run)

    def start(self):
        """Execute every queued run one after another, persisting outputs."""
        for cur in self.runnable:
            cur.prep_dirs()

            cur.output = exp.run_exp_local(
                cur.experiment, cur.env, verbose=self.verbose)
            self.complete.append(cur)

            # make sure the output directory exists before writing
            out = pathlib.Path(cur.outpath)
            out.parent.mkdir(parents=True, exist_ok=True)
            with open(cur.outpath, 'w') as f:
                f.write(cur.output.dumps())

class LocalParallelRuntime(Runtime):
    """Executes runs concurrently on the local machine, subject to core and
    memory limits and run prerequisites."""

    def __init__(self, cores, mem=None, verbose=False):
        # cores: total number of CPU cores available (None = unlimited)
        # mem: total memory available in MB (None = unlimited)
        self.runs_noprereq = []   # runs without prerequisites
        self.runs_prereq = []     # runs that depend on another run
        self.complete = set()     # runs that have finished
        self.cores = cores
        self.mem = mem
        self.verbose = verbose

    def add_run(self, run):
        """Queue a run, validating that it can ever fit the resource limits.

        Raises:
            Exception: if the run needs more cores or memory than this
                runtime will ever have available.
        """
        # guard against cores being None (unlimited), matching the
        # treatment in enough_resources() -- previously a TypeError
        if self.cores is not None and \
                run.experiment.resreq_cores() > self.cores:
            raise Exception('Not enough cores available for run')

        if self.mem is not None and run.experiment.resreq_mem() > self.mem:
            raise Exception('Not enough memory available for run')

        if run.prereq is None:
            self.runs_noprereq.append(run)
        else:
            self.runs_prereq.append(run)

    async def do_run(self, run):
        ''' actually starts a run '''
        run.prep_dirs()

        await run.experiment.prepare(run.env, verbose=self.verbose)
        print('starting run ', run.name())
        run.output = await run.experiment.run(run.env, verbose=self.verbose)

        # persist the run output before reporting it complete
        pathlib.Path(run.outpath).parent.mkdir(parents=True, exist_ok=True)
        with open(run.outpath, 'w') as f:
            f.write(run.output.dumps())
        print('finished run ', run.name())
        return run

    async def wait_completion(self):
        ''' wait for any run to terminate and return '''
        assert self.pending_jobs

        done, self.pending_jobs = await asyncio.wait(self.pending_jobs,
                return_when=asyncio.FIRST_COMPLETED)

        for job in done:
            run = job.result()
            self.complete.add(run)
            # release the resources the finished run was holding
            self.cores_used -= run.experiment.resreq_cores()
            self.mem_used -= run.experiment.resreq_mem()

    def enough_resources(self, run):
        ''' check if enough cores and mem are available for the run '''
        # renamed local from `exp` to avoid shadowing the module import
        experiment = run.experiment

        if self.cores is not None:
            enough_cores = \
                (self.cores - self.cores_used) >= experiment.resreq_cores()
        else:
            enough_cores = True

        if self.mem is not None:
            enough_mem = \
                (self.mem - self.mem_used) >= experiment.resreq_mem()
        else:
            enough_mem = True

        return enough_cores and enough_mem

    def prereq_ready(self, run):
        ''' check if the run's prerequisite (if any) has completed '''
        if run.prereq is None:
            return True

        return run.prereq in self.complete

    async def do_start(self):
        ''' schedule all queued runs and wait for them to finish '''
        self.cores_used = 0
        self.mem_used = 0
        self.pending_jobs = set()

        runs = self.runs_noprereq + self.runs_prereq
        for run in runs:
            # check if we first have to wait for memory or cores
            while not self.enough_resources(run):
                print('waiting for resources')
                await self.wait_completion()

            # check if we first have to wait for the prerequisite run
            while not self.prereq_ready(run):
                print('waiting for prereq')
                await self.wait_completion()

            self.cores_used += run.experiment.resreq_cores()
            self.mem_used += run.experiment.resreq_mem()

            # wrap the coroutine in a task: asyncio.wait() no longer
            # accepts bare coroutines (removed in Python 3.11)
            job = asyncio.create_task(self.do_run(run))
            self.pending_jobs.add(job)

        # wait for all runs to finish
        while self.pending_jobs:
            await self.wait_completion()

    def start(self):
        asyncio.run(self.do_start())

class SlurmRuntime(Runtime):
    """Submits queued runs as batch jobs to a slurm cluster."""

    def __init__(self, slurmdir, args, verbose=False, cleanup=True):
        # slurmdir: directory for pickled runs, batch scripts, and logs
        # args: extra slurm arguments -- NOTE(review): currently unused;
        #       confirm whether they should be passed to sbatch
        # cleanup: if True, the batch script removes the workdir afterwards
        self.runnable = []
        self.slurmdir = slurmdir
        self.args = args
        self.verbose = verbose
        self.cleanup = cleanup

    def add_run(self, run):
        """Queue a run for submission."""
        self.runnable.append(run)

    def prep_run(self, run):
        """Write the pickled run and its slurm batch script to slurmdir.

        Returns:
            Path of the generated batch script.
        """
        experiment = run.experiment
        base = experiment.name + f'-{run.index}'

        exp_path = os.path.join(self.slurmdir, base + '.exp')    # pickled run
        exp_log = os.path.join(self.slurmdir, base + '.log')     # job output
        exp_script = os.path.join(self.slurmdir, base + '.sh')   # batch script
        print(exp_path)
        print(exp_log)
        print(exp_script)

        # write out pickled run
        with open(exp_path, 'wb') as f:
            run.prereq = None # we don't want to pull in the prereq too
            pickle.dump(run, f)

        # create slurm batch script
        with open(exp_script, 'w') as f:
            f.write('#!/bin/sh\n')
            f.write('#SBATCH -o %s -e %s\n' % (exp_log, exp_log))
            f.write('#SBATCH -c %d\n' % (experiment.resreq_cores(),))
            f.write('#SBATCH --mem=%dM\n' % (experiment.resreq_mem(),))
            f.write('#SBATCH --job-name="%s"\n' % (run.name(),))
            if experiment.timeout is not None:
                # convert the timeout in seconds to slurm's HH:MM:SS format
                h, rem = divmod(int(experiment.timeout), 3600)
                m, s = divmod(rem, 60)
                f.write('#SBATCH --time=%02d:%02d:%02d\n' % (h, m, s))

            extra = ''
            if self.verbose:
                extra = '--verbose'

            f.write('python3 run.py %s --pickled %s\n' % (extra, exp_path))
            f.write('status=$?\n')
            if self.cleanup:
                f.write('rm -rf %s\n' % (run.env.workdir))
            f.write('exit $status\n')

        return exp_script

    def start(self):
        """Submit all queued runs via sbatch, chaining prerequisites with
        slurm job dependencies.

        Raises:
            Exception: if sbatch exits with an error or its output cannot
                be parsed for a job id.
        """
        pathlib.Path(self.slurmdir).mkdir(parents=True, exist_ok=True)

        jid_re = re.compile(r'Submitted batch job ([0-9]+)')

        for run in self.runnable:
            # a run with a prerequisite may only start after that job succeeds
            if run.prereq is None:
                dep_cmd = ''
            else:
                dep_cmd = '--dependency=afterok:' + str(run.prereq.job_id)

            script = self.prep_run(run)

            stream = os.popen('sbatch %s %s' % (dep_cmd, script))
            output = stream.read()
            result = stream.close()

            if result is not None:
                raise Exception('running sbatch failed')

            m = jid_re.search(output)
            if m is None:
                # fail clearly instead of an AttributeError on m.group()
                raise Exception(
                    'could not parse job id from sbatch output: %r' % output)
            run.job_id = int(m.group(1))