"test/srt/test_w8a8_quantization.py" did not exist on "9ba1f0976035fe7212002cac3b2b9df9f0685334"
runtime.py 6.67 KB
Newer Older
1
import asyncio
2
3
4
import pickle
import os
import pathlib
5
import re
6

7
8
9
import modes.experiments as exp

class Run(object):
    """A single execution of an experiment, tracked by a runtime.

    Couples an experiment with the environment it runs in, the path its
    serialized output is written to, and (optionally) another run that
    must complete before this one may start.
    """

    def __init__(self, experiment, index, env, outpath, prereq=None):
        # output is filled in by the runtime once the run has finished
        self.output = None
        self.experiment = experiment
        self.index = index
        self.env = env
        self.outpath = outpath
        self.prereq = prereq

    def name(self):
        """Unique label for this run: '<experiment name>.<index>'."""
        return f'{self.experiment.name}.{self.index}'
class Runtime(object):
    """Base interface shared by all run-execution backends.

    Concrete runtimes collect runs via add_run() and execute them all
    when start() is invoked.
    """

    def add_run(self, run):
        """Register *run* for later execution. No-op in the base class."""
        pass

    def start(self):
        """Execute all registered runs. No-op in the base class."""
        pass
class LocalSimpleRuntime(Runtime):
    """Executes runs sequentially on the local machine, one at a time."""

    def __init__(self, verbose=False):
        self.verbose = verbose
        self.runnable = []  # runs waiting to be executed
        self.complete = []  # finished runs, in execution order

    def add_run(self, run):
        """Queue *run*; it is executed when start() is called."""
        self.runnable.append(run)

    def start(self):
        """Execute every queued run back to back, persisting each output."""
        for r in self.runnable:
            r.output = exp.run_exp_local(
                r.experiment, r.env, verbose=self.verbose)
            self.complete.append(r)

            # persist the serialized output for this run
            with open(r.outpath, 'w') as out_file:
                out_file.write(r.output.dumps())
class LocalParallelRuntime(Runtime):
    """Executes runs concurrently on the local machine, subject to a core
    budget and an optional memory budget.

    Runs without prerequisites are scheduled first; a run with a prereq
    is only started once that prerequisite run has completed.
    """

    def __init__(self, cores, mem=None, verbose=False):
        """cores: total CPU cores available for concurrent runs.
        mem: total memory budget (same unit as resreq_mem()), or None
        for no memory limit."""
        self.runs_noprereq = []
        self.runs_prereq = []
        self.complete = set()
        self.cores = cores
        self.mem = mem
        self.verbose = verbose

    def add_run(self, run):
        """Queue *run*, verifying it can ever fit within the budget.

        Raises Exception if the run requires more cores (or memory, when
        a memory budget is set) than this runtime has in total.
        """
        if run.experiment.resreq_cores() > self.cores:
            raise Exception('Not enough cores available for run')

        if self.mem is not None and run.experiment.resreq_mem() > self.mem:
            raise Exception('Not enough memory available for run')

        if run.prereq is None:
            self.runs_noprereq.append(run)
        else:
            self.runs_prereq.append(run)

    async def do_run(self, run):
        ''' actually starts a run '''
        await run.experiment.prepare(run.env, verbose=self.verbose)
        print('starting run ', run.name())
        run.output = await run.experiment.run(run.env, verbose=self.verbose)

        # persist the serialized output for this run
        with open(run.outpath, 'w') as f:
            f.write(run.output.dumps())
        print('finished run ', run.name())
        return run

    async def wait_completion(self):
        ''' wait for any run to terminate and return '''
        assert self.pending_jobs

        done, self.pending_jobs = await asyncio.wait(self.pending_jobs,
                return_when=asyncio.FIRST_COMPLETED)

        for task in done:
            run = await task
            self.complete.add(run)
            # release the resources the finished run was holding
            self.cores_used -= run.experiment.resreq_cores()
            self.mem_used -= run.experiment.resreq_mem()

    def enough_resources(self, run):
        ''' check if enough cores and mem are available for the run '''
        # named 'experiment' (not 'exp') to avoid shadowing the module
        experiment = run.experiment

        if self.cores is not None:
            enough_cores = \
                (self.cores - self.cores_used) >= experiment.resreq_cores()
        else:
            enough_cores = True

        if self.mem is not None:
            enough_mem = \
                (self.mem - self.mem_used) >= experiment.resreq_mem()
        else:
            enough_mem = True

        return enough_cores and enough_mem

    def prereq_ready(self, run):
        ''' check if the run's prerequisite (if any) has completed '''
        if run.prereq is None:
            return True

        return run.prereq in self.complete

    async def do_start(self):
        ''' schedule all queued runs and wait for them to finish '''
        self.cores_used = 0
        self.mem_used = 0
        self.pending_jobs = set()

        runs = self.runs_noprereq + self.runs_prereq
        for run in runs:
            # wait until enough cores/memory are free AND the prereq
            # (if any) has completed
            while not self.enough_resources(run) or \
                    not self.prereq_ready(run):
                print('waiting for resources')
                await self.wait_completion()

            self.cores_used += run.experiment.resreq_cores()
            self.mem_used += run.experiment.resreq_mem()

            # wrap the coroutine in a task: asyncio.wait() only accepts
            # tasks/futures (bare coroutines were removed in Python 3.11)
            job = asyncio.ensure_future(self.do_run(run))
            self.pending_jobs.add(job)

        # wait for all runs to finish
        while self.pending_jobs:
            await self.wait_completion()

    def start(self):
        """Run all queued runs to completion (blocking)."""
        asyncio.run(self.do_start())
class SlurmRuntime(Runtime):
    """Submits runs as batch jobs to a SLURM cluster via sbatch."""

    def __init__(self, slurmdir, args, verbose=False, cleanup=True):
        """slurmdir: directory for generated batch scripts, pickled
        experiments, and job logs.
        args: parsed command-line arguments (repo, workdir, outdir, ...).
        cleanup: if True, the job removes its working directory on exit."""
        self.runnable = []
        self.slurmdir = slurmdir
        self.args = args
        self.verbose = verbose
        self.cleanup = cleanup

    def add_run(self, run):
        """Queue *run* for submission when start() is called."""
        self.runnable.append(run)

    def prep_run(self, run):
        """Write the pickled experiment and an sbatch script for *run*.

        Returns the path of the generated batch script.
        """
        # named 'experiment' (not 'exp') to avoid shadowing the module
        experiment = run.experiment
        exp_path = '%s/%s-%d.exp' % (self.slurmdir, experiment.name,
                run.index)
        exp_log = '%s/%s-%d.log' % (self.slurmdir, experiment.name,
                run.index)
        exp_script = '%s/%s-%d.sh' % (self.slurmdir, experiment.name,
                run.index)

        # write out pickled experiment
        with open(exp_path, 'wb') as f:
            pickle.dump(experiment, f)

        # create slurm batch script
        with open(exp_script, 'w') as f:
            f.write('#!/bin/sh\n')
            f.write('#SBATCH -o %s -e %s\n' % (exp_log, exp_log))
            f.write('#SBATCH -c %d\n' % (experiment.resreq_cores(),))
            f.write('#SBATCH --mem=%dM\n' % (experiment.resreq_mem(),))
            f.write('#SBATCH --job-name="%s"\n' % (run.name(),))
            if experiment.timeout is not None:
                # convert the timeout in seconds to HH:MM:SS for SLURM
                h = int(experiment.timeout / 3600)
                m = int((experiment.timeout % 3600) / 60)
                s = int(experiment.timeout % 60)
                f.write('#SBATCH --time=%02d:%02d:%02d\n' % (h, m, s))

            f.write('mkdir -p %s\n' % (self.args.workdir))
            f.write(('python3 run.py --repo=%s --workdir=%s --outdir=%s '
                '--firstrun=%d --runs=1 %s\n') % (self.args.repo,
                    self.args.workdir, self.args.outdir, run.index,
                    exp_path))
            f.write('status=$?\n')
            if self.cleanup:
                f.write('rm -rf %s\n' % (run.env.workdir))
            f.write('exit $status\n')

        return exp_script

    def start(self):
        """Submit all queued runs to SLURM, wiring up job dependencies.

        Raises Exception if sbatch fails or its output cannot be parsed.
        """
        pathlib.Path(self.slurmdir).mkdir(parents=True, exist_ok=True)

        jid_re = re.compile(r'Submitted batch job ([0-9]+)')

        for run in self.runnable:
            if run.prereq is None:
                dep_cmd = ''
            else:
                # the prereq was submitted earlier in this loop, so its
                # job_id is already known
                dep_cmd = '--dependency=afterok:' + str(run.prereq.job_id)

            script = self.prep_run(run)

            stream = os.popen('sbatch %s %s' % (dep_cmd, script))
            output = stream.read()
            result = stream.close()

            if result is not None:
                raise Exception('running sbatch failed')

            m = jid_re.search(output)
            if m is None:
                # sbatch exited successfully but printed something we
                # cannot parse a job id from
                raise Exception('could not parse sbatch output: ' + output)
            run.job_id = int(m.group(1))