trial.ts 12.9 KB
Newer Older
Lijiaoa's avatar
Lijiaoa committed
1
import * as JSON5 from 'json5';
2
3
4
5
6
7
8
9
import {
    MetricDataRecord,
    TrialJobInfo,
    TableObj,
    TableRecord,
    Parameters,
    FinalType,
    MultipleAxes,
10
    SingleAxis
11
12
13
14
15
16
17
18
} from '../interface';
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
19
20
    formatComplexTypeValue,
    reformatRetiariiParameter
21
} from '../function';
22

23
24
25
26
27
28
29
/**
 * Get a structured representation of parameters
 *
 * Every entry of the (reformatted) parameter object is looked up by key in
 * `space.axes`. Known keys are stored in the first returned map under their
 * axis object; unknown keys are stored in the second map under
 * `prefix + key`. When the axis is flagged `nested` and the value is an
 * object carrying `_name`, the axis is mapped to that `_name`, and the value
 * is parsed recursively against the sub space registered for `_name` in
 * `axisKey.domain` (if there is one), using `prefix + key + '/'` as the new
 * namespace.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    // NOTE(review): `reformatRetiariiParameter` is defined elsewhere; it is also
    // re-applied to every nested value through the recursive call below.
    const latestedParamObj: object = reformatRetiariiParameter(paramObj);
    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();

    for (const [k, v] of Object.entries(latestedParamObj)) {
        // Inside a nested choice, `_name` is the selector itself (already recorded
        // by the parent call), not a parameter of the sub space.
        if (prefix && k === '_name') continue;

        const axisKey = space.axes.get(k);
        if (axisKey === undefined) {
            // The namespaced key is the only identification available for an
            // entry that has no matching axis.
            unexpectedEntries.set(prefix + k, formatComplexTypeValue(v));
            continue;
        }

        // `typeof null === 'object'`, so null has to be excluded before reading `_name`.
        const isNestedChoice = v !== null && typeof v === 'object' && v._name !== undefined && axisKey.nested;
        if (!isNestedChoice) {
            parameters.set(axisKey, formatComplexTypeValue(v));
            continue;
        }

        // nested entry
        parameters.set(axisKey, v._name);
        const subSpace = axisKey.domain.get(v._name);
        if (subSpace !== undefined) {
            const [subParams, subUnexpected] = inferTrialParameters(v, subSpace, prefix + k + '/');
            subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
            subUnexpected.forEach((subValue, subName) => unexpectedEntries.set(subName, subValue));
        }
    }
    return [parameters, unexpectedEntries];
}

62
63
64
/**
 * One trial as seen by the web UI: its job info plus the metric records
 * received for it.
 *
 * Records whose `type` is `'PERIODICAL'` are stored in `intermediates` at the
 * index given by their `sequence`; any other record is kept as `final`, and
 * its accuracy (as computed by `metricAccuracy`) is cached.
 */
class Trial implements TableObj {
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    /** Intermediate records indexed by `sequence`; slots not received yet are `undefined`. */
    public intermediates: (MetricDataRecord | undefined)[] = [];
    /** Latest final metric record, if any. */
    public final: MetricDataRecord | undefined;
    /** Accuracy derived from `final` through `metricAccuracy`. */
    private finalAcc: number | undefined;

    /**
     * @param info Job info of the trial; may be supplied later through `updateTrialJobInfo`
     * @param metrics All metric records known for the trial so far
     */
    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Difference between this trial's final accuracy and `otherTrial`'s.
     * @returns `undefined` when either trial is not `sortable`
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.finalAcc! - otherTrial.finalAcc!;
    }

    /** Job info. Only meaningful once `initialized()` is true; the value is not checked here. */
    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    /** Leading run of intermediate records, cut at the first missing slot. */
    get intermediateMetrics(): MetricDataRecord[] {
        const ret: MetricDataRecord[] = [];
        for (const record of this.intermediates) {
            if (!record) {
                break;
            }
            ret.push(record);
        }
        return ret;
    }

    /** Cached final accuracy, `undefined` until a final metric has been seen. */
    get accuracy(): number | undefined {
        return this.finalAcc;
    }

    /** True when metrics have been loaded and the final accuracy is a finite number. */
    get sortable(): boolean {
        return this.metricsInitialized && this.finalAcc !== undefined && isFinite(this.finalAcc);
    }

    /**
     * Final accuracy when available; otherwise the value of the last
     * intermediate record: the number itself, or its `default` entry when the
     * parsed metric is an object that owns one.
     * @returns `undefined` when there is nothing to report, or when the parsed
     *          metric is an array, `null`, or any other unsupported shape
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        if (this.intermediates.length === 0) {
            return undefined;
        }
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }
        // Parse once and reuse the result for every shape check.
        const parsed = parseMetrics(latest.data);
        if (isArrayType(parsed)) {
            return undefined;
        }
        if (
            parsed !== null && // `typeof null` is 'object' as well
            typeof parsed === 'object' &&
            Object.prototype.hasOwnProperty.call(parsed, 'default')
        ) {
            return parsed.default;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        return undefined;
    }

    /* table obj start */

    /** Flat record describing this trial, assembled from the job info and metric state. */
    get tableRecord(): TableRecord {
        const duration = this.duration;
        const finalMetric = this.acc;
        let accuracy;
        if (finalMetric !== undefined && finalMetric.default !== undefined) {
            if (typeof finalMetric.default === 'number') {
                // NOTE(review): numeric values are deliberately still routed through
                // JSON5.parse, exactly as before — confirm whether that is intended.
                accuracy = JSON5.parse(finalMetric.default);
            } else {
                accuracy = finalMetric.default;
            }
        }

        return {
            key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration,
            status: this.info.status,
            message: this.info.message || '--',
            intermediateCount: this.intermediates.length,
            accuracy: accuracy,
            latestAccuracy: this.latestAccuracy,
            formattedLatestAccuracy: this.formatLatestAccuracy()
        };
    }

    /** Row key; note that this is the sequence id, while `tableRecord.key` is the job id. */
    get key(): number {
        return this.info.sequenceId;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /**
     * `(endTime - startTime) / 1000`; the current time stands in for a missing
     * (or zero) `endTime`.
     */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    /** Result of `getFinal` on the job info's `finalMetricData`; `undefined` without job info. */
    get acc(): FinalType | undefined {
        if (this.infoField === undefined) {
            return undefined;
        }
        return getFinal(this.infoField.finalMetricData);
    }

    /** Keys of the final metric whose values are numbers; `['default']` when there is no final metric. */
    get accuracyNumberTypeDictKeys(): string[] {
        const finalMetric = this.acc;
        if (finalMetric === undefined) {
            return ['default'];
        }
        return Object.entries(finalMetric)
            .filter(([, value]) => typeof value === 'number')
            .map(([name]) => name);
    }

    /**
     * Parameters of the last hyper-parameter entry, the intermediate values and
     * the log path (when set) of this trial.
     *
     * A missing or empty hyper-parameter list yields an `error` message in
     * place of the parameters.
     */
    get description(): Parameters {
        const ret: Parameters = {
            parameters: {},
            intermediate: [],
            multiProgress: 1
        };
        const tempHyper = this.info.hyperParameters;
        if (tempHyper !== undefined && tempHyper.length > 0) {
            const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
            ret.multiProgress = tempHyper.length;
            // `parameters` may itself be a JSON-encoded string
            ret.parameters = typeof getPara === 'string' ? JSON.parse(getPara) : getPara;
        } else {
            ret.parameters = { error: "This trial's parameters are not available." };
        }
        if (this.info.logPath !== undefined) {
            ret.logPath = this.info.logPath;
        }

        const mediate: number[] = [];
        for (const items of this.intermediateMetrics) {
            const parsed = parseMetrics(items.data);
            // Objects contribute their `default` entry; anything else (including
            // null, which must not be dereferenced) is passed through unchanged.
            if (parsed !== null && typeof parsed === 'object') {
                mediate.push(parsed.default);
            } else {
                mediate.push(parsed);
            }
        }
        ret.intermediate = mediate;
        return ret;
    }

    /**
     * Value of every axis of `axes` for this trial; axes without a value map to `null`.
     * @throws the default (all-`null`) map when job info or hyper-parameters are
     *         missing or empty
     * @throws the map of unexpected entries when the parameters contain keys
     *         that `inferTrialParameters` could not match to an axis
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        const tempHyper = this.infoField === undefined ? undefined : this.infoField.hyperParameters;
        if (tempHyper === undefined || tempHyper.length === 0) {
            throw ret;
        }

        let params = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
        if (typeof params === 'string') {
            params = JSON.parse(params);
        }
        const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        for (const [k, v] of updated) {
            ret.set(k, v);
        }
        return ret;
    }

    /**
     * Final metric values keyed by the axes of `space`; axes without a value map to `null`.
     * @throws the map of final-metric entries that have no axis in `space`
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const finalMetric = this.acc;
        if (finalMetric === undefined) {
            return ret;
        }
        const acc = typeof finalMetric === 'number' ? { default: finalMetric } : finalMetric;
        const unexpectedEntries = new Map<string, any>();
        for (const [k, v] of Object.entries(acc)) {
            const column = space.axes.get(k);
            if (column !== undefined) {
                ret.set(column, v);
            } else {
                unexpectedEntries.set(k, v);
            }
        }
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    /** Keys of the final metric, or an empty list when there is none. */
    public finalKeys(): string[] {
        const finalMetric = this.acc;
        return finalMetric !== undefined ? Object.keys(finalMetric) : [];
    }

    /* table obj end */

    /** Whether job info has been set. */
    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Store a complete list of metric records.
     * @param metrics must contain all known metrics of this trial
     * @returns false when the list is not longer than what is already stored
     *          (nothing is written in that case), true otherwise
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        // parameter `metrics` must contain all known metrics of this trial
        this.metricsInitialized = true;
        const prevMetricCnt = this.intermediates.length + (this.final ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Store a (possibly partial) list of metric records.
     * @returns true when at least one record filled a slot that was empty before
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        // this method is effectively identical to `updateMetrics`, but has worse performance
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                updated = updated || !this.final;
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Replace the job info and, when it carries final metric data, refresh the
     * final record and cached accuracy from its last entry.
     * @returns true unless job info already existed with the same `status`
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        const finals = trialJobInfo.finalMetricData;
        // An empty list carries no final metric: keep the current state rather
        // than handing `undefined` to `metricAccuracy`.
        if (finals && finals.length > 0) {
            this.final = finals[finals.length - 1];
            this.finalAcc = metricAccuracy(this.final);
        }
        return !same;
    }

    /**
     * Text form of a metric value: NaN/Infinity verbatim, other numbers
     * formatted and tagged `(FINAL)` or `(LATEST)` depending on whether a final
     * accuracy exists, everything else as JSON.
     */
    private renderNumber(val: any): string {
        if (typeof val !== 'number') {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
        if (isNaNorInfinity(val)) {
            return `${val}`; // show 'NaN' or 'Infinity'
        }
        const tag = this.accuracy === undefined ? 'LATEST' : 'FINAL';
        return `${formatAccuracy(val)} (${tag})`;
    }

    /**
     * Display string for the most relevant accuracy: the final one when it
     * exists, otherwise (for trials whose status is not `'SUCCEEDED'`) the last
     * intermediate one; `'--'` when neither is available.
     */
    public formatLatestAccuracy(): string {
        // TODO: this should be private
        if (this.status === 'SUCCEEDED') {
            return this.accuracy === undefined ? '--' : this.renderNumber(this.accuracy);
        }
        if (this.accuracy !== undefined) {
            return this.renderNumber(this.accuracy);
        }
        if (this.intermediates.length === 0) {
            return '--';
        }
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        const latest = this.intermediates[this.intermediates.length - 1]!;
        return this.renderNumber(metricAccuracy(latest));
    }
}

export { Trial };