trial.ts 12.3 KB
Newer Older
Lijiaoa's avatar
Lijiaoa committed
1
import * as JSON5 from 'json5';
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import {
    MetricDataRecord,
    TrialJobInfo,
    TableObj,
    TableRecord,
    Parameters,
    FinalType,
    MultipleAxes,
    SingleAxis
} from '../interface';
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
    formatComplexTypeValue
} from '../function';
21

22
23
24
25
26
27
28
/**
 * Get a structured representation of parameters.
 *
 * Walks the entries of `paramObj` and matches each key against `space.axes`:
 *  - a key with a matching axis is stored in the first returned map, keyed by that axis;
 *  - when the axis is `nested` and the value is an object carrying `_name`, the chosen
 *    `_name` is stored for the axis and the value is parsed recursively against the
 *    sub-space found in `axis.domain` (if there is one);
 *  - a key with no matching axis is stored in the second returned map under `prefix + key`.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();
    for (const [k, v] of Object.entries(paramObj)) {
        // inside a nested entry (non-empty prefix) `_name` is the selector itself, not a parameter
        if (prefix && k === '_name') continue;

        const axisKey = space.axes.get(k);
        if (axisKey === undefined) {
            // prefix can be a good fallback when corresponding item is not found in namespace
            unexpectedEntries.set(prefix + k, formatComplexTypeValue(v));
            continue;
        }

        // `typeof null` is 'object', so null must be excluded before `_name` is read
        const isNestedEntry = v !== null && typeof v === 'object' && v._name !== undefined && axisKey.nested;
        if (!isNestedEntry) {
            parameters.set(axisKey, formatComplexTypeValue(v));
            continue;
        }

        // nested entry: record the chosen branch, then merge whatever the branch contains
        parameters.set(axisKey, v._name);
        const subSpace = axisKey.domain.get(v._name);
        if (subSpace !== undefined) {
            const [subParams, subUnexpected] = inferTrialParameters(v, subSpace, prefix + k + '/');
            subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
            subUnexpected.forEach((subValue, subName) => unexpectedEntries.set(subName, subValue));
        }
    }
    return [parameters, unexpectedEntries];
}

60
61
62
/**
 * Model of a single trial: holds the job info object plus the intermediate and final metric
 * records received so far, and exposes the derived values (accuracy, duration, table row,
 * parameters, ...) computed from them.
 */
class Trial implements TableObj {
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    /** `PERIODICAL` metric records stored at index `metric.sequence`; missing slots are `undefined`. */
    public intermediates: (MetricDataRecord | undefined)[] = [];
    /** Last non-`PERIODICAL` metric record that was stored, if any. */
    public final: MetricDataRecord | undefined;
    private finalAcc: number | undefined;

    /**
     * @param info Job info to start with (optional)
     * @param metrics Metric records to load immediately through `updateMetrics` (optional)
     */
    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Difference between this trial's final accuracy and `otherTrial`'s.
     * @returns `undefined` when either trial is not `sortable`
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.finalAcc! - otherTrial.finalAcc!;
    }

    /**
     * The stored job info. The non-null assertion is compile-time only: if no info has been
     * set yet this still evaluates to `undefined` at runtime (see `initialized()`).
     */
    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    /** Intermediate records from index 0 up to (not including) the first missing slot. */
    get intermediateMetrics(): MetricDataRecord[] {
        const ret: MetricDataRecord[] = [];
        for (let i = 0; i < this.intermediates.length; i++) {
            if (this.intermediates[i]) {
                // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
                ret.push(this.intermediates[i]!);
            } else {
                break;
            }
        }
        return ret;
    }

    /** Accuracy derived from the final metric, or `undefined` when none has been stored. */
    get accuracy(): number | undefined {
        return this.finalAcc;
    }

    /** True once metrics have been loaded and the final accuracy is a defined, non-NaN number. */
    get sortable(): boolean {
        return this.metricsInitialized && this.finalAcc !== undefined && !isNaN(this.finalAcc);
    }

    /**
     * The final accuracy if there is one; otherwise a value taken from the last intermediate
     * record: the parsed metric itself when it is a number, or its `default` entry when it is
     * an object that owns one.
     * @returns `undefined` when there is nothing usable (no records, an array metric, a null
     * metric, an object without `default`, or any other type)
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        if (this.intermediates.length === 0) {
            return undefined;
        }
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }

        // parse once; every branch below inspects the same value
        const parsed = parseMetrics(latest.data);
        // `typeof null` is 'object', so rule null out before the object branch touches it
        if (parsed === null || parsed === undefined) {
            return undefined;
        }
        if (isArrayType(parsed)) {
            return undefined;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        if (typeof parsed === 'object' && Object.prototype.hasOwnProperty.call(parsed, 'default')) {
            return parsed.default;
        }
        return undefined;
    }

    /* table obj start */

    /** Plain record describing this trial, built from the job info and current metrics. */
    get tableRecord(): TableRecord {
        let accuracy;
        const finalMetrics = this.acc;
        if (finalMetrics !== undefined && finalMetrics.default !== undefined) {
            if (typeof finalMetrics.default === 'number') {
                accuracy = JSON5.parse(finalMetrics.default);
            } else {
                accuracy = finalMetrics.default;
            }
        }

        return {
            key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration: this.duration,
            status: this.info.status,
            intermediateCount: this.intermediates.length,
            accuracy: accuracy,
            latestAccuracy: this.latestAccuracy,
            formattedLatestAccuracy: this.formatLatestAccuracy(),
            accDictionary: finalMetrics
        };
    }

    /** Returns the sequence id (note that `tableRecord.key` uses the trial job id instead). */
    get key(): number {
        return this.info.sequenceId;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /**
     * `(endTime - startTime) / 1000`; the current time is used while `endTime` is not set
     * (or is otherwise falsy).
     */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    /** Result of `getFinal` on the job info's `finalMetricData`; `undefined` without job info. */
    get acc(): FinalType | undefined {
        if (this.info === undefined) {
            return undefined;
        }
        return getFinal(this.info.finalMetricData);
    }

    /**
     * Summary of the trial: the parameters parsed from the last `hyperParameters` entry, the
     * number of entries as `multiProgress`, the log path when present, and one value per
     * record of `intermediateMetrics`.
     * When `hyperParameters` is missing or empty, `parameters` carries an error message instead.
     */
    get description(): Parameters {
        const ret: Parameters = {
            parameters: {},
            intermediate: [],
            multiProgress: 1
        };
        const tempHyper = this.info.hyperParameters;
        // an empty list has no last entry to parse, so it is treated like a missing one
        if (tempHyper !== undefined && tempHyper.length > 0) {
            const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
            ret.multiProgress = tempHyper.length;
            // `parameters` may itself be a JSON string
            ret.parameters = typeof getPara === 'string' ? JSON.parse(getPara) : getPara;
        } else {
            ret.parameters = { error: "This trial's parameters are not available." };
        }
        if (this.info.logPath !== undefined) {
            ret.logPath = this.info.logPath;
        }

        const mediate: number[] = [];
        for (const items of this.intermediateMetrics) {
            const parsed = parseMetrics(items.data);
            // null is also typeof 'object' but has no `default` to read
            if (parsed !== null && typeof parsed === 'object') {
                mediate.push(parsed.default);
            } else {
                mediate.push(parsed);
            }
        }
        ret.intermediate = mediate;
        return ret;
    }

    /**
     * Value of every axis in `axes` for this trial; axes the trial has no value for map to `null`.
     * @throws the all-`null` map itself when job info or `hyperParameters` is missing
     * @throws the map of unexpected entries when the trial has parameters that match no axis
     * (both are thrown as plain `Map` values, not `Error` objects)
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        if (this.info === undefined || this.info.hyperParameters === undefined) {
            throw ret;
        } else {
            const tempHyper = this.info.hyperParameters;
            let params = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
            if (typeof params === 'string') {
                params = JSON.parse(params);
            }
            const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
            if (unexpectedEntries.size) {
                throw unexpectedEntries;
            }
            for (const [k, v] of updated) {
                ret.set(k, v);
            }
            return ret;
        }
    }

    /**
     * Value of every axis in `space` taken from the final metrics (`acc`); axes without a
     * value map to `null`.
     * @throws the map of unexpected entries (a plain `Map`) when `acc` has keys matching no axis
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const unexpectedEntries = new Map<string, any>();
        if (this.acc === undefined) {
            return ret;
        }
        const acc = typeof this.acc === 'number' ? { default: this.acc } : this.acc;
        Object.entries(acc).forEach(item => {
            const [k, v] = item;
            const column = space.axes.get(k);
            if (column !== undefined) {
                ret.set(column, v);
            } else {
                unexpectedEntries.set(k, v);
            }
        });
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    get color(): string | undefined {
        return undefined;
    }

    /** Keys of `acc`, or an empty list when there are no final metrics. */
    public finalKeys(): string[] {
        if (this.acc !== undefined) {
            return Object.keys(this.acc);
        } else {
            return [];
        }
    }

    /* table obj end */

    /** Whether job info has been set. */
    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Store a full list of metric records.
     * @returns `false` (and stores nothing) when `metrics` is not longer than what is already held
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        // parameter `metrics` must contain all known metrics of this trial
        this.metricsInitialized = true;
        const prevMetricCnt = this.intermediates.length + (this.final ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Store metric records one by one.
     * @returns `true` when at least one record filled a slot that was empty before
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        // this method is effectively identical to `updateMetrics`, but has worse performance
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                updated = updated || !this.final;
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Replace the job info; when it carries `finalMetricData`, its last record becomes the
     * final metric.
     * @returns `true` unless info was already present with the same `status`
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        if (trialJobInfo.finalMetricData) {
            this.final = trialJobInfo.finalMetricData[trialJobInfo.finalMetricData.length - 1];
            this.finalAcc = metricAccuracy(this.final);
        }
        return !same;
    }

    /**
     * Text form of a metric value: numbers are formatted and tagged `(FINAL)` or `(LATEST)`
     * depending on whether a final accuracy exists; anything else is JSON-stringified.
     */
    private renderNumber(val: any): string {
        if (typeof val === 'number') {
            if (isNaNorInfinity(val)) {
                return `${val}`; // show 'NaN' or 'Infinity'
            } else {
                if (this.accuracy === undefined) {
                    return `${formatAccuracy(val)} (LATEST)`;
                } else {
                    return `${formatAccuracy(val)} (FINAL)`;
                }
            }
        } else {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
    }

    /**
     * Display string for the newest accuracy: the final accuracy when present, otherwise (for
     * trials whose status is not `SUCCEEDED`) the last intermediate record, otherwise `'--'`.
     */
    public formatLatestAccuracy(): string {
        // TODO: this should be private
        if (this.status === 'SUCCEEDED') {
            return this.accuracy === undefined ? '--' : this.renderNumber(this.accuracy);
        } else {
            if (this.accuracy !== undefined) {
                return this.renderNumber(this.accuracy);
            } else if (this.intermediates.length === 0) {
                return '--';
            } else {
                // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
                const latest = this.intermediates[this.intermediates.length - 1]!;
                return this.renderNumber(metricAccuracy(latest));
            }
        }
    }
}

export { Trial };