// trial.ts
import * as JSON5 from 'json5';
import {
    MetricDataRecord,
    TrialJobInfo,
    TableObj,
    TableRecord,
    Parameters,
    FinalType,
    MultipleAxes,
    SingleAxis
} from '../interface';
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
    formatComplexTypeValue,
    reformatRetiariiParameter
} from '../function';

/**
 * Get a structured representation of parameters.
 *
 * Each entry of `paramObj` (after normalization by `reformatRetiariiParameter`) whose key
 * matches an axis in `space` is recorded against that axis. When the axis is nested and the
 * value is an object carrying a `_name`, the `_name` is recorded as the axis value and the
 * object is resolved recursively against the sub search space selected by that name.
 * Entries with no matching axis are reported as unexpected, keyed by their full
 * `prefix`-qualified name.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    const normalizedParamObj: object = reformatRetiariiParameter(paramObj);

    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();
    for (const [key, value] of Object.entries(normalizedParamObj)) {
        // Inside a nested namespace, `_name` is the chosen branch, which the parent call
        // has already recorded as the nested axis value.
        if (prefix && key === '_name') {
            continue;
        }
        const axis = space.axes.get(key);
        if (axis === undefined) {
            // prefix is a good fallback when the item is not found in this namespace
            unexpectedEntries.set(prefix + key, formatComplexTypeValue(value));
            continue;
        }
        // `typeof null === 'object'`, so null must be ruled out before reading `_name`.
        if (value !== null && typeof value === 'object' && value._name !== undefined && axis.nested) {
            // nested entry
            parameters.set(axis, value._name);
            const subSpace = axis.domain.get(value._name);
            if (subSpace !== undefined) {
                const [subParams, subUnexpected] = inferTrialParameters(value, subSpace, prefix + key + '/');
                subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
                subUnexpected.forEach((subValue, subKey) => unexpectedEntries.set(subKey, subValue));
            }
        } else {
            parameters.set(axis, formatComplexTypeValue(value));
        }
    }
    return [parameters, unexpectedEntries];
}

/**
 * A single trial: its job info plus the intermediate and final metrics received so far.
 * Implements `TableObj` (see the "table obj" section) so it can be rendered as a table row.
 */
class Trial implements TableObj {
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    // Indexed by metric sequence number; slots stay empty until that metric arrives.
    public intermediates: (MetricDataRecord | undefined)[] = [];
    public final: MetricDataRecord | undefined;
    private finalAcc: number | undefined;

    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Compare this trial's final accuracy with another trial's.
     * @returns `this - other`, or `undefined` when either trial is not sortable.
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // `sortable` guarantees both finalAcc values are defined and finite.
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.finalAcc! - otherTrial.finalAcc!;
    }

    /** Job info of this trial; only safe to read when `initialized()` returns true. */
    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    /** Intermediate metrics in sequence order, stopping at the first missing sequence number. */
    get intermediateMetrics(): MetricDataRecord[] {
        const ret: MetricDataRecord[] = [];
        // for..of visits empty slots of a sparse array as `undefined`.
        for (const metric of this.intermediates) {
            if (!metric) {
                break;
            }
            ret.push(metric);
        }
        return ret;
    }

    /** Final accuracy, if a final metric has been received. */
    get accuracy(): number | undefined {
        return this.finalAcc;
    }

    /** True once metrics have been loaded and the final accuracy is a finite number. */
    get sortable(): boolean {
        return this.metricsInitialized && this.finalAcc !== undefined && isFinite(this.finalAcc);
    }

    /**
     * Final accuracy when known; otherwise derived from the last intermediate metric:
     * its `default` field when it parses to an object with one, or the value itself when
     * it parses to a number. Returns `undefined` in every other case (including arrays).
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        // Index -1 on an empty array yields undefined, covered by the check below.
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }
        const parsed: any = parseMetrics(latest.data);
        if (isArrayType(parsed)) {
            return undefined;
        }
        // `typeof null === 'object'`, so null must be excluded before the property check.
        if (parsed !== null && typeof parsed === 'object' && Object.prototype.hasOwnProperty.call(parsed, 'default')) {
            return parsed.default;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        return undefined;
    }

    /* table obj start */

    get tableRecord(): TableRecord {
        // Evaluate the `acc` getter once; it is used for both accuracy and accDictionary.
        const finalResult = this.acc;
        let accuracy;
        if (finalResult !== undefined && finalResult.default !== undefined) {
            if (typeof finalResult.default === 'number') {
                // NOTE(review): JSON5.parse on a number round-trips it through its string
                // form; this looks equivalent to using the value directly — confirm.
                accuracy = JSON5.parse(finalResult.default);
            } else {
                accuracy = finalResult.default;
            }
        }

        return {
            key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration: this.duration,
            status: this.info.status,
            message: this.info.message || '--',
            intermediateCount: this.intermediates.length,
            accuracy: accuracy,
            latestAccuracy: this.latestAccuracy,
            formattedLatestAccuracy: this.formatLatestAccuracy(),
            accDictionary: finalResult
        };
    }

    get key(): number {
        return this.info.sequenceId;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /** Elapsed seconds; a trial without `endTime` is measured up to the current time. */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    /** Final result as returned by `getFinal`, or `undefined` when there is no job info. */
    get acc(): FinalType | undefined {
        if (this.infoField === undefined) {
            return undefined;
        }
        return getFinal(this.infoField.finalMetricData);
    }

    /**
     * Parameters (from the latest hyper-parameter record), log path and the default
     * value of each contiguous intermediate metric, as shown in the trial detail view.
     */
    get description(): Parameters {
        const ret: Parameters = {
            parameters: {},
            intermediate: [],
            multiProgress: 1
        };
        const tempHyper = this.info.hyperParameters;
        if (tempHyper !== undefined && tempHyper.length > 0) {
            ret.parameters = Trial.parseLatestParameters(tempHyper);
            ret.multiProgress = tempHyper.length;
        } else {
            ret.parameters = { error: "This trial's parameters are not available." };
        }
        if (this.info.logPath !== undefined) {
            ret.logPath = this.info.logPath;
        }

        const mediate: number[] = this.intermediateMetrics.map(item => {
            const parsed: any = parseMetrics(item.data);
            // Object metrics contribute their `default` field; null is not dereferenced.
            return parsed !== null && typeof parsed === 'object' ? parsed.default : parsed;
        });
        ret.intermediate = mediate;
        return ret;
    }

    /**
     * Parse the `parameters` field of the most recent hyper-parameter record.
     * The record is a JSON string; its `parameters` may itself be a JSON-encoded string.
     * Callers must ensure `hyperParameters` is non-empty.
     * @throws SyntaxError when the record or its nested parameters string is invalid JSON.
     */
    private static parseLatestParameters(hyperParameters: NonNullable<TrialJobInfo['hyperParameters']>): any {
        let params = JSON.parse(hyperParameters[hyperParameters.length - 1]).parameters;
        if (typeof params === 'string') {
            params = JSON.parse(params);
        }
        return params;
    }

    /**
     * Map every axis of `axes` to this trial's value for it (`null` when absent).
     * @throws The all-null Map when the trial has no hyper-parameters.
     * @throws A Map of unexpected entries (full name -> value) when parameters don't fit `axes`.
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        const tempHyper = this.infoField === undefined ? undefined : this.infoField.hyperParameters;
        if (tempHyper === undefined || tempHyper.length === 0) {
            // NOTE(review): a Map (not an Error) is thrown; presumably callers catch it and
            // use it as an all-null fallback — confirm before changing.
            throw ret;
        }
        const params = Trial.parseLatestParameters(tempHyper);
        const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        for (const [k, v] of updated) {
            ret.set(k, v);
        }
        return ret;
    }

    /**
     * Map every axis of `space` to the matching final-metric entry (`null` when absent).
     * A bare numeric final result is treated as `{ default: value }`.
     * @throws A Map of final-metric keys that have no axis in `space`.
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const finalResult = this.acc;
        if (finalResult === undefined) {
            return ret;
        }
        const acc = typeof finalResult === 'number' ? { default: finalResult } : finalResult;
        const unexpectedEntries = new Map<string, any>();
        for (const [k, v] of Object.entries(acc)) {
            const column = space.axes.get(k);
            if (column !== undefined) {
                ret.set(column, v);
            } else {
                unexpectedEntries.set(k, v);
            }
        }
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    get color(): string | undefined {
        return undefined;
    }

    /** Keys of the final result, or an empty list when there is none. */
    public finalKeys(): string[] {
        const finalResult = this.acc;
        return finalResult === undefined ? [] : Object.keys(finalResult);
    }

    /* table obj end */

    /** True once job info has been set. */
    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Store a full snapshot of this trial's metrics.
     * @returns false (storing nothing) when the snapshot has no more entries than already held.
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        // parameter `metrics` must contain all known metrics of this trial
        this.metricsInitialized = true;
        const prevMetricCnt = this.intermediates.length + (this.final ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Merge a batch of metrics.
     * @returns true when a previously empty intermediate slot or the final metric was filled.
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        // this method is effectively identical to `updateMetrics`, but has worse performance
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                updated = updated || !this.final;
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Replace the job info, picking up the last final metric it carries.
     * @returns true when there was no previous info or the status changed.
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        const finalMetrics = trialJobInfo.finalMetricData;
        // An empty array would make [length - 1] undefined and hand that to metricAccuracy.
        if (finalMetrics && finalMetrics.length > 0) {
            this.final = finalMetrics[finalMetrics.length - 1];
            this.finalAcc = metricAccuracy(this.final);
        }
        return !same;
    }

    /** Numbers get a (LATEST)/(FINAL) tag; NaN/Infinity and non-numbers are shown verbatim. */
    private renderNumber(val: any): string {
        if (typeof val !== 'number') {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
        if (isNaNorInfinity(val)) {
            return `${val}`; // show 'NaN' or 'Infinity'
        }
        return this.accuracy === undefined ? `${formatAccuracy(val)} (LATEST)` : `${formatAccuracy(val)} (FINAL)`;
    }

    /** Display string for the final accuracy, or the latest intermediate one while running. */
    public formatLatestAccuracy(): string {
        // TODO: this should be private
        if (this.status === 'SUCCEEDED') {
            return this.accuracy === undefined ? '--' : this.renderNumber(this.accuracy);
        }
        if (this.accuracy !== undefined) {
            return this.renderNumber(this.accuracy);
        }
        if (this.intermediates.length === 0) {
            return '--';
        }
        // The last index was written by an assignment, so it is never an empty slot.
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        const latest = this.intermediates[this.intermediates.length - 1]!;
        return this.renderNumber(metricAccuracy(latest));
    }
}

export { Trial };