"magic_pdf/vscode:/vscode.git/clone" did not exist on "2c4a586eb1d1c6ad65b1fc322a1d9d28795b1484"
trial.ts 12.4 KB
Newer Older
Lijiaoa's avatar
Lijiaoa committed
1
import * as JSON5 from 'json5';
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import {
    MetricDataRecord,
    TrialJobInfo,
    TableObj,
    TableRecord,
    Parameters,
    FinalType,
    MultipleAxes,
    SingleAxis
} from '../interface';
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
    formatComplexTypeValue
} from '../function';
21

22
23
24
25
26
27
28
/**
 * Get a structured representation of parameters.
 *
 * Each entry of `paramObj` is matched by key against `space.axes`:
 * - a key with no matching axis is reported as an unexpected entry, named `prefix + key`;
 * - a key whose axis is `nested` and whose value is a (non-null) object carrying `_name`
 *   is recorded as `_name`, then resolved recursively against `axis.domain.get(_name)`
 *   with `prefix + key + '/'` as the new namespace;
 * - any other matching key is recorded as `formatComplexTypeValue(value)`.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();
    for (const [key, value] of Object.entries(paramObj)) {
        // Inside a sub space, `_name` only names the chosen branch; the caller already recorded it.
        if (prefix && key === '_name') {
            continue;
        }
        const axis = space.axes.get(key);
        if (axis === undefined) {
            // prefix gives the full name of an entry not found in this namespace
            unexpectedEntries.set(prefix + key, formatComplexTypeValue(value));
            continue;
        }
        // `typeof null === 'object'`, so null must be excluded before reading `_name`.
        const isNestedChoice =
            axis.nested && typeof value === 'object' && value !== null && value._name !== undefined;
        if (!isNestedChoice) {
            parameters.set(axis, formatComplexTypeValue(value));
            continue;
        }
        parameters.set(axis, value._name);
        const subSpace = axis.domain.get(value._name);
        if (subSpace !== undefined) {
            const [subParams, subUnexpected] = inferTrialParameters(value, subSpace, prefix + key + '/');
            subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
            subUnexpected.forEach((subValue, subKey) => unexpectedEntries.set(subKey, subValue));
        }
    }
    return [parameters, unexpectedEntries];
}

60
61
62
/**
 * One trial: its job information (`TrialJobInfo`) plus the metrics reported for it.
 *
 * Implements `TableObj` so a trial can be shown directly as a table row.
 */
class Trial implements TableObj {
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    // Intermediate results stored at index `metric.sequence`; missing slots are undefined.
    public intermediates: (MetricDataRecord | undefined)[] = [];
    // Most recently received final result, if any.
    public final: MetricDataRecord | undefined;
    // `metricAccuracy(final)`, cached whenever `final` is set.
    private finalAcc: number | undefined;

    /**
     * @param info Job information, if already known
     * @param metrics Metrics known so far for this trial, if any
     */
    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Compare final accuracy with another trial.
     * @returns `this` accuracy minus `otherTrial` accuracy (positive when this trial is higher),
     *          or undefined when either trial is not sortable.
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.finalAcc! - otherTrial.finalAcc!;
    }

    /** Job information; callers are expected to check `initialized()` first. */
    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    /** Intermediate metrics in sequence order, cut off at the first missing sequence number. */
    get intermediateMetrics(): MetricDataRecord[] {
        const ret: MetricDataRecord[] = [];
        // for...of also visits holes of the sparse array (as undefined), which ends the run.
        for (const metric of this.intermediates) {
            if (!metric) {
                break;
            }
            ret.push(metric);
        }
        return ret;
    }

    /** Final accuracy, or undefined if no final metric has been received. */
    get accuracy(): number | undefined {
        return this.finalAcc;
    }

    /** True once metrics were initialized and the final accuracy is a defined, non-NaN number. */
    get sortable(): boolean {
        return this.metricsInitialized && this.finalAcc !== undefined && !isNaN(this.finalAcc);
    }

    /**
     * Final accuracy if present; otherwise the last intermediate metric when it parses to a
     * number or to an object with a `default` field. Undefined in every other case.
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        if (this.intermediates.length === 0) {
            return undefined;
        }
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }
        // Parse once rather than once per check.
        const parsed = parseMetrics(latest.data);
        if (isArrayType(parsed)) {
            return undefined;
        }
        // `typeof null === 'object'`; exclude null before probing for `default`.
        if (typeof parsed === 'object' && parsed !== null && Object.prototype.hasOwnProperty.call(parsed, 'default')) {
            return parsed.default;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        return undefined;
    }

    /* table obj start */

    /** Row data for the trials table. */
    get tableRecord(): TableRecord {
        const acc = this.acc;
        let accuracy;
        if (acc !== undefined && acc.default !== undefined) {
            if (typeof acc.default === 'number') {
                // NOTE(review): JSON5.parse stringifies its input, so this appears to round-trip
                // the number unchanged — confirm before simplifying.
                accuracy = JSON5.parse(acc.default);
            } else {
                accuracy = acc.default;
            }
        }

        return {
            key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration: this.duration,
            status: this.info.status,
            message: this.info.message || '--',
            intermediateCount: this.intermediates.length,
            accuracy: accuracy,
            latestAccuracy: this.latestAccuracy,
            formattedLatestAccuracy: this.formatLatestAccuracy(),
            accDictionary: acc
        };
    }

    get key(): number {
        return this.info.sequenceId;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /** Elapsed seconds from start to end time, or to now while the trial has no end time. */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    /** Final metric as returned by `getFinal`, or undefined when job info is missing. */
    get acc(): FinalType | undefined {
        if (this.infoField === undefined) {
            return undefined;
        }
        return getFinal(this.infoField.finalMetricData);
    }

    /**
     * Parameters (from the last hyper-parameter record), log path and intermediate values,
     * as shown in the trial detail view.
     */
    get description(): Parameters {
        const ret: Parameters = {
            parameters: {},
            intermediate: [],
            multiProgress: 1
        };
        const tempHyper = this.info.hyperParameters;
        // An empty list would make `tempHyper[-1]` undefined and JSON.parse throw.
        if (tempHyper !== undefined && tempHyper.length > 0) {
            const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
            ret.multiProgress = tempHyper.length;
            // parameters may arrive double-encoded as a JSON string
            ret.parameters = typeof getPara === 'string' ? JSON.parse(getPara) : getPara;
        } else {
            ret.parameters = { error: "This trial's parameters are not available." };
        }
        if (this.info.logPath !== undefined) {
            ret.logPath = this.info.logPath;
        }

        const mediate: number[] = this.intermediateMetrics.map(item => {
            const parsed = parseMetrics(item.data);
            // object metrics contribute their `default` field; null is kept as-is
            return typeof parsed === 'object' && parsed !== null ? parsed.default : parsed;
        });
        ret.intermediate = mediate;
        return ret;
    }

    /**
     * Map every axis of `axes` to this trial's value for it (null when absent).
     * @throws The all-null default map when no hyper-parameters are available.
     * @throws A Map of unexpected entries (full name -> value) when parameters do not fit `axes`.
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        if (
            this.infoField === undefined ||
            this.infoField.hyperParameters === undefined ||
            this.infoField.hyperParameters.length === 0
        ) {
            throw ret;
        }
        const tempHyper = this.infoField.hyperParameters;
        let params = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
        if (typeof params === 'string') {
            params = JSON.parse(params);
        }
        const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        for (const [k, v] of updated) {
            ret.set(k, v);
        }
        return ret;
    }

    /**
     * Map every axis of `space` to the matching key of the final metric (null when absent).
     * A plain-number final metric is treated as `{ default: value }`.
     * @throws A Map of final-metric keys that have no axis in `space`.
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const finalMetric = this.acc;
        if (finalMetric === undefined) {
            return ret;
        }
        const unexpectedEntries = new Map<string, any>();
        const acc = typeof finalMetric === 'number' ? { default: finalMetric } : finalMetric;
        for (const [k, v] of Object.entries(acc)) {
            const column = space.axes.get(k);
            if (column !== undefined) {
                ret.set(column, v);
            } else {
                unexpectedEntries.set(k, v);
            }
        }
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    get color(): string | undefined {
        return undefined;
    }

    /** Keys of the final metric, or an empty list when there is none. */
    public finalKeys(): string[] {
        const acc = this.acc;
        return acc !== undefined ? Object.keys(acc) : [];
    }

    /* table obj end */

    /** Whether job information has been set. */
    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Replace metrics with a full list.
     * @param metrics Must contain all known metrics of this trial.
     * @returns false (without storing anything) when the list is not longer than what is
     *          already stored; true otherwise.
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        this.metricsInitialized = true;
        const prevMetricCnt = this.intermediates.length + (this.final ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Merge in metrics without the length shortcut used by `updateMetrics`.
     * @returns true when a previously empty intermediate slot or final metric got filled.
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                updated = updated || !this.final;
                this.final = metric;
                this.finalAcc = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Store new job information; the last entry of `finalMetricData` (if any) becomes `final`.
     * @returns true unless the previous info had the same status.
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        const finalMetricData = trialJobInfo.finalMetricData;
        // An empty array is truthy; without the length check `final` would become undefined.
        if (finalMetricData && finalMetricData.length > 0) {
            this.final = finalMetricData[finalMetricData.length - 1];
            this.finalAcc = metricAccuracy(this.final);
        }
        return !same;
    }

    /**
     * Render a metric for display. Finite numbers are formatted and tagged `(FINAL)` when the
     * trial has a final accuracy, `(LATEST)` otherwise; values flagged by `isNaNorInfinity`
     * are printed verbatim; non-numbers are JSON-stringified.
     */
    private renderNumber(val: any): string {
        if (typeof val !== 'number') {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
        if (isNaNorInfinity(val)) {
            return `${val}`; // show 'NaN' or 'Infinity'
        }
        const tag = this.accuracy === undefined ? 'LATEST' : 'FINAL';
        return `${formatAccuracy(val)} (${tag})`;
    }

    /**
     * Display text for the most relevant accuracy: the final one when present; otherwise '--'
     * for succeeded trials or trials without intermediates, else the last intermediate value.
     */
    public formatLatestAccuracy(): string {
        // TODO: this should be private
        if (this.accuracy !== undefined) {
            return this.renderNumber(this.accuracy);
        }
        if (this.status === 'SUCCEEDED' || this.intermediates.length === 0) {
            return '--';
        }
        // The last index is always one that was assigned, so it is never a hole.
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        const latest = this.intermediates[this.intermediates.length - 1]!;
        return this.renderNumber(metricAccuracy(latest));
    }
}

export { Trial };