trial.ts 12.4 KB
Newer Older
Lijiaoa's avatar
Lijiaoa committed
1
import * as JSON5 from 'json5';
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import {
    MetricDataRecord,
    TrialJobInfo,
    TableObj,
    TableRecord,
    Parameters,
    FinalType,
    MultipleAxes,
    SingleAxis
} from '../interface';
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
    formatComplexTypeValue
} from '../function';
21

22
23
24
25
26
27
28
/**
 * Get a structured representation of parameters.
 *
 * Walks every entry of `paramObj` against the axes of `space`. When an axis is
 * marked `nested` and the value carries a `_name` field, `_name` is recorded as
 * that axis' value and selects the sub search space `axisKey.domain.get(_name)`,
 * which is then walked recursively with `prefix + key + '/'` as namespace.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries (keyed by full name)
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();
    for (const [k, v] of Object.entries(paramObj)) {
        // inside a sub space, `_name` was already consumed by the parent as the chosen branch
        if (prefix && k === '_name') {
            continue;
        }
        const axisKey = space.axes.get(k);
        if (axisKey === undefined) {
            // prefix is a good fallback name when the item is not found in this namespace
            unexpectedEntries.set(prefix + k, formatComplexTypeValue(v));
            continue;
        }
        // `typeof null === 'object'`, so null must be excluded before reading `_name`
        if (typeof v === 'object' && v !== null && v._name !== undefined && axisKey.nested) {
            // nested entry
            parameters.set(axisKey, v._name);
            const subSpace = axisKey.domain.get(v._name);
            if (subSpace !== undefined) {
                const [subParams, subUnexpected] = inferTrialParameters(v, subSpace, prefix + k + '/');
                subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
                subUnexpected.forEach((subValue, subName) => unexpectedEntries.set(subName, subValue));
            }
        } else {
            parameters.set(axisKey, formatComplexTypeValue(v));
        }
    }
    return [parameters, unexpectedEntries];
}

60
61
62
/**
 * A single trial: its job info plus the intermediate and final metrics
 * received so far. Implements `TableObj` so it can be shown as a table row.
 */
class Trial implements TableObj {
    // set by any update*Metrics call; `sortable` requires it
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    // indexed by `metric.sequence`, so it may contain holes until all PERIODICAL metrics arrive
    public intermediates: (MetricDataRecord | undefined)[] = [];
    public final: MetricDataRecord | undefined;
    private finalAcc: number | undefined;

    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Difference of final accuracies (this - other), usable as a sort comparator.
     * @returns undefined when either trial is not sortable
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.finalAcc! - otherTrial.finalAcc!;
    }

    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    /** Contiguous prefix of intermediate metrics, stopping at the first missing sequence number. */
    get intermediateMetrics(): MetricDataRecord[] {
        const ret: MetricDataRecord[] = [];
        for (const metric of this.intermediates) {
            if (!metric) {
                break;
            }
            ret.push(metric);
        }
        return ret;
    }

    get accuracy(): number | undefined {
        return this.finalAcc;
    }

    get sortable(): boolean {
        return this.metricsInitialized && this.finalAcc !== undefined && isFinite(this.finalAcc);
    }

    /**
     * Final accuracy when present; otherwise the value of the last intermediate
     * metric (its `default` field, or the metric itself when it is a number).
     * Returns undefined for array payloads and any other unrecognized shape.
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        if (this.intermediates.length === 0) {
            return undefined;
        }
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }
        // parse once instead of once per branch
        const parsed = parseMetrics(latest.data);
        if (isArrayType(parsed)) {
            return undefined;
        }
        // `typeof null === 'object'`, so null must be excluded before calling hasOwnProperty
        // eslint-disable-next-line no-prototype-builtins
        if (typeof parsed === 'object' && parsed !== null && parsed.hasOwnProperty('default')) {
            return parsed.default;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        return undefined;
    }

    /* table obj start */

    get tableRecord(): TableRecord {
        // `acc` re-runs getFinal on every access; read it once
        const acc = this.acc;
        let accuracy;
        if (acc !== undefined && acc.default !== undefined) {
            if (typeof acc.default === 'number') {
                accuracy = JSON5.parse(acc.default);
            } else {
                accuracy = acc.default;
            }
        }

        return {
            key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration: this.duration,
            status: this.info.status,
            message: this.info.message || '--',
            intermediateCount: this.intermediates.length,
            accuracy: accuracy,
            latestAccuracy: this.latestAccuracy,
            formattedLatestAccuracy: this.formatLatestAccuracy(),
            accDictionary: acc
        };
    }

    get key(): number {
        return this.info.sequenceId;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /** Elapsed seconds; uses the current time while the trial has no end time. */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    get acc(): FinalType | undefined {
        if (this.info === undefined) {
            return undefined;
        }
        return getFinal(this.info.finalMetricData);
    }

    /**
     * Detail view of the trial: the latest hyper-parameter set, how many sets
     * were recorded (`multiProgress`), the log path, and the intermediate values.
     */
    get description(): Parameters {
        const ret: Parameters = {
            parameters: {},
            intermediate: [],
            multiProgress: 1
        };
        const tempHyper = this.info.hyperParameters;
        if (tempHyper !== undefined) {
            const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
            ret.multiProgress = tempHyper.length;
            // parameters may arrive double-encoded as a JSON string
            ret.parameters = typeof getPara === 'string' ? JSON.parse(getPara) : getPara;
        } else {
            ret.parameters = { error: "This trial's parameters are not available." };
        }
        if (this.info.logPath !== undefined) {
            ret.logPath = this.info.logPath;
        }

        const mediate: number[] = [];
        for (const items of this.intermediateMetrics) {
            const parsed = parseMetrics(items.data);
            // `typeof null === 'object'`: null is pushed as-is instead of crashing on `.default`
            if (typeof parsed === 'object' && parsed !== null) {
                mediate.push(parsed.default);
            } else {
                mediate.push(parsed);
            }
        }
        ret.intermediate = mediate;
        return ret;
    }

    /**
     * Map the latest hyper-parameters onto `axes`; axes without a value stay null.
     * @throws the all-null map when no hyper-parameters are available
     * @throws a Map of full name -> value for entries that match no axis
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        if (this.info === undefined || this.info.hyperParameters === undefined) {
            throw ret;
        }
        const tempHyper = this.info.hyperParameters;
        let params = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
        if (typeof params === 'string') {
            params = JSON.parse(params);
        }
        const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        for (const [k, v] of updated) {
            ret.set(k, v);
        }
        return ret;
    }

    /**
     * Map every final-metric key onto the matching axis of `space`; axes without a value stay null.
     * @throws a Map of metric key -> value for keys that match no axis
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const unexpectedEntries = new Map<string, any>();
        const finalMetric = this.acc;
        if (finalMetric === undefined) {
            return ret;
        }
        const acc = typeof finalMetric === 'number' ? { default: finalMetric } : finalMetric;
        for (const [k, v] of Object.entries(acc)) {
            const column = space.axes.get(k);
            if (column !== undefined) {
                ret.set(column, v);
            } else {
                unexpectedEntries.set(k, v);
            }
        }
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    get color(): string | undefined {
        return undefined;
    }

    public finalKeys(): string[] {
        const acc = this.acc;
        return acc !== undefined ? Object.keys(acc) : [];
    }

    /* table obj end */

    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Replace metrics with the full known list.
     * @returns false when `metrics` is not longer than what is already stored
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        // parameter `metrics` must contain all known metrics of this trial
        this.metricsInitialized = true;
        const prevMetricCnt = this.intermediates.length + (this.final ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Merge newly received metrics.
     * @returns true if any previously empty intermediate slot or the final metric got filled
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        // this method is effectively identical to `updateMetrics`, but has worse performance
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                updated = updated || !this.final;
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Replace the job info and pick up its last final metric, if any.
     * @returns true when the status changed (or there was no previous info)
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        const finalMetrics = trialJobInfo.finalMetricData;
        // an empty array is truthy; indexing it would hand `undefined` to metricAccuracy
        if (finalMetrics && finalMetrics.length > 0) {
            this.final = finalMetrics[finalMetrics.length - 1];
            this.finalAcc = metricAccuracy(this.final);
        }
        return !same;
    }

    /**
     * Render a metric for display: NaN/Infinity verbatim, other numbers via
     * formatAccuracy tagged (FINAL) or (LATEST), non-numbers as JSON.
     */
    private renderNumber(val: any): string {
        if (typeof val === 'number') {
            if (isNaNorInfinity(val)) {
                return `${val}`; // show 'NaN' or 'Infinity'
            } else {
                if (this.accuracy === undefined) {
                    return `${formatAccuracy(val)} (LATEST)`;
                } else {
                    return `${formatAccuracy(val)} (FINAL)`;
                }
            }
        } else {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
    }

    public formatLatestAccuracy(): string {
        // TODO: this should be private
        if (this.status === 'SUCCEEDED') {
            return this.accuracy === undefined ? '--' : this.renderNumber(this.accuracy);
        }
        if (this.accuracy !== undefined) {
            return this.renderNumber(this.accuracy);
        }
        // covers both "no intermediates yet" and a missing last slot
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return '--';
        }
        return this.renderNumber(metricAccuracy(latest));
    }
}

export { Trial };