trial.ts 11.1 KB
Newer Older
1
import { MetricDataRecord, TrialJobInfo, TableRecord, FinalType, MultipleAxes, SingleAxis } from '../interface';
2
3
4
5
6
7
8
import {
    getFinal,
    formatAccuracy,
    metricAccuracy,
    parseMetrics,
    isArrayType,
    isNaNorInfinity,
9
10
    formatComplexTypeValue,
    reformatRetiariiParameter
11
} from '../function';
12

13
14
15
16
17
18
19
/**
 * Get a structured representation of parameters.
 *
 * The parameter object is first normalized with `reformatRetiariiParameter`, then every entry is
 * matched against the axes of `space`. An entry whose axis is `nested` and whose value is an object
 * carrying `_name` is treated as a nested choice: the chosen `_name` is recorded for that axis and
 * the value is resolved recursively against the sub search space `axis.domain.get(_name)`.
 *
 * @param paramObj Parameters object
 * @param space All axes from search space (or sub search space)
 * @param prefix Current namespace (to make full name for unexpected entries)
 * @returns Parsed structured parameters and unexpected entries (unexpected keys are prefixed with
 *          the namespace path, e.g. `outer/inner`)
 */
function inferTrialParameters(
    paramObj: object,
    space: MultipleAxes,
    prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] {
    const normalizedParamObj: object = reformatRetiariiParameter(paramObj);

    const parameters = new Map<SingleAxis, any>();
    const unexpectedEntries = new Map<string, any>();
    for (const [key, value] of Object.entries(normalizedParamObj)) {
        // Inside a nested namespace `_name` only identifies the chosen branch, which the caller
        // has already recorded for the parent axis; skip it.
        if (prefix && key === '_name') {
            continue;
        }

        const axis = space.axes.get(key);
        if (axis === undefined) {
            // prefix can be a good fallback when corresponding item is not found in namespace
            unexpectedEntries.set(prefix + key, formatComplexTypeValue(value));
            continue;
        }

        // `typeof null === 'object'`, so null must be excluded before reading `_name`.
        if (typeof value === 'object' && value !== null && value._name !== undefined && axis.nested) {
            // nested entry: record the chosen branch, then resolve its sub search space
            parameters.set(axis, value._name);
            const subSpace = axis.domain.get(value._name);
            if (subSpace !== undefined) {
                const [subParams, subUnexpected] = inferTrialParameters(value, subSpace, prefix + key + '/');
                subParams.forEach((subValue, subAxis) => parameters.set(subAxis, subValue));
                subUnexpected.forEach((subValue, subKey) => unexpectedEntries.set(subKey, subValue));
            }
        } else {
            parameters.set(axis, formatComplexTypeValue(value));
        }
    }
    return [parameters, unexpectedEntries];
}

52
/**
 * Wraps a trial's job info (`TrialJobInfo`) together with the metrics reported for it, and derives
 * the values shown in the trials table (accuracy, duration, parameters, metric columns).
 */
class Trial {
    // set once any metric list has been applied via updateMetrics / updateLatestMetrics
    private metricsInitialized: boolean = false;
    private infoField: TrialJobInfo | undefined;
    public accuracy: number | undefined; // trial default metric val: number value or undefined
    // PERIODICAL metrics indexed by their `sequence` (may contain holes)
    public intermediates: (MetricDataRecord | undefined)[] = [];

    constructor(info?: TrialJobInfo, metrics?: MetricDataRecord[]) {
        this.infoField = info;
        if (metrics) {
            this.updateMetrics(metrics);
        }
    }

    /**
     * Compare the default metric of this trial with another trial's.
     * @returns `this.accuracy - otherTrial.accuracy`, or undefined when either trial is not sortable
     */
    public compareAccuracy(otherTrial: Trial): number | undefined {
        if (!this.sortable || !otherTrial.sortable) {
            return undefined;
        }
        // `sortable` guarantees both accuracies are finite numbers
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.accuracy! - otherTrial.accuracy!;
    }

    get info(): TrialJobInfo {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return this.infoField!;
    }

    get sequenceId(): number {
        return this.info.sequenceId;
    }

    get id(): string {
        return this.info.trialJobId;
    }

    /** Elapsed seconds from start to end time, or to now if the trial has no end time yet. */
    get duration(): number {
        const endTime = this.info.endTime || new Date().getTime();
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return (endTime - this.info.startTime!) / 1000;
    }

    get status(): string {
        return this.info.status;
    }

    /** The `parameters` field of the first hyper-parameter JSON string of this trial. */
    get parameter(): object {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return JSON.parse(this.info.hyperParameters![0]).parameters;
    }

    // return dict final result: {default: xxx...}
    get acc(): FinalType | undefined {
        if (this.info === undefined) {
            return undefined;
        }
        return getFinal(this.info.finalMetricData);
    }

    /**
     * Map every axis of `axes` to this trial's value for it (null when the trial has no value).
     *
     * @throws the default-filled Map itself (not an Error) when the trial has no hyper-parameters
     * @throws a `Map<string, any>` of unexpected entries when the parameters contain keys that are
     *         not part of `axes`
     */
    public parameters(axes: MultipleAxes): Map<SingleAxis, any> {
        // every axis starts as null; axes absent from the trial's parameters keep that value
        const ret = new Map<SingleAxis, any>(Array.from(axes.axes.values()).map(k => [k, null]));
        if (this.info === undefined || this.info.hyperParameters === undefined) {
            throw ret;
        }

        let params = JSON.parse(this.info.hyperParameters[0]).parameters;
        // parameters may be double-encoded as a JSON string
        if (typeof params === 'string') {
            params = JSON.parse(params);
        }
        // for hpo experiment: search space choice value is None, and it shows null
        for (const [key, value] of Object.entries(params)) {
            if (value === null) {
                params[key] = 'null';
            }
        }

        const [updated, unexpectedEntries] = inferTrialParameters(params, axes);
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        for (const [axis, value] of updated) {
            ret.set(axis, value);
        }
        return ret;
    }

    /** True once metrics were applied and the final accuracy is a finite number. */
    get sortable(): boolean {
        return this.metricsInitialized && this.accuracy !== undefined && isFinite(this.accuracy);
    }

    /**
     * The final accuracy if present; otherwise the default value of the last intermediate metric
     * (its `default` key when it is an object, or the value itself when it is a number).
     * Returns undefined when nothing usable is available.
     */
    get latestAccuracy(): number | undefined {
        if (this.accuracy !== undefined) {
            return this.accuracy;
        }
        if (this.intermediates.length === 0) {
            return undefined;
        }
        const latest = this.intermediates[this.intermediates.length - 1];
        if (latest === undefined) {
            return undefined;
        }

        // parse once instead of re-parsing the same payload for every check
        const parsed = parseMetrics(latest.data);
        if (isArrayType(parsed)) {
            return undefined;
        }
        if (typeof parsed === 'number') {
            return parsed;
        }
        // `typeof null === 'object'`, so guard against null before the property lookup
        if (typeof parsed === 'object' && parsed !== null && Object.prototype.hasOwnProperty.call(parsed, 'default')) {
            return parsed.default;
        }
        return undefined;
    }

    /** Keys of the final-result dict whose values are numbers; `['default']` when there is no final result. */
    get accuracyNumberTypeDictKeys(): string[] {
        const final = this.acc;
        if (final === undefined) {
            return ['default'];
        }
        return Object.entries(final)
            .filter(([, value]) => typeof value === 'number')
            .map(([key]) => key);
    }

    /* table obj start */

    get tableRecord(): TableRecord {
        return {
            _key: this.info.trialJobId,
            sequenceId: this.info.sequenceId,
            id: this.info.trialJobId,
            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
            startTime: this.info.startTime!,
            endTime: this.info.endTime,
            duration: this.duration,
            status: this.info.status,
            message: this.info.message ?? '--',
            intermediateCount: this.intermediates.length,
            latestAccuracy: this.latestAccuracy,
            _formattedLatestAccuracy: this.formatLatestAccuracy()
        };
    }

    /**
     * Map every metric axis of `space` to this trial's final-result value (null when absent).
     * @throws a `Map<string, any>` of final-result keys that are not part of `space`
     */
    public metrics(space: MultipleAxes): Map<SingleAxis, any> {
        // set default value: null
        const ret = new Map<SingleAxis, any>(Array.from(space.axes.values()).map(k => [k, null]));
        const final = this.acc;
        if (final === undefined) {
            return ret;
        }

        const unexpectedEntries = new Map<string, any>();
        for (const [key, value] of Object.entries(final)) {
            const column = space.axes.get(key);
            if (column !== undefined) {
                ret.set(column, value);
            } else {
                unexpectedEntries.set(key, value);
            }
        }
        if (unexpectedEntries.size) {
            throw unexpectedEntries;
        }
        return ret;
    }

    /* table obj end */

    public initialized(): boolean {
        return Boolean(this.infoField);
    }

    /**
     * Replace this trial's metrics with `metrics`.
     * @param metrics must contain all known metrics of this trial
     * @returns false (and changes nothing) when `metrics` is not longer than what is already known
     */
    public updateMetrics(metrics: MetricDataRecord[]): boolean {
        this.metricsInitialized = true;
        // a final accuracy of 0 is still a known metric, so test for undefined rather than truthiness
        const prevMetricCnt = this.intermediates.length + (this.accuracy !== undefined ? 1 : 0);
        if (metrics.length <= prevMetricCnt) {
            return false;
        }
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                this.intermediates[metric.sequence] = metric;
            } else {
                this.accuracy = metricAccuracy(metric);
            }
        }
        return true;
    }

    /**
     * Merge `metrics` into this trial.
     * @returns true when at least one previously unknown intermediate or final metric was added
     */
    public updateLatestMetrics(metrics: MetricDataRecord[]): boolean {
        // this method is effectively identical to `updateMetrics`, but has worse performance
        this.metricsInitialized = true;
        let updated = false;
        for (const metric of metrics) {
            if (metric.type === 'PERIODICAL') {
                updated = updated || !this.intermediates[metric.sequence];
                this.intermediates[metric.sequence] = metric;
            } else {
                // a final accuracy of 0 is a known value, not a missing one
                updated = updated || this.accuracy === undefined;
                this.accuracy = metricAccuracy(metric);
            }
        }
        return updated;
    }

    /**
     * Replace the job info and, when it carries a final metric, the final accuracy.
     * @returns true when the trial status changed (or there was no previous info)
     */
    public updateTrialJobInfo(trialJobInfo: TrialJobInfo): boolean {
        const same = this.infoField && this.infoField.status === trialJobInfo.status;
        this.infoField = trialJobInfo;
        // an empty finalMetricData array has no final metric to read
        const finalMetric = trialJobInfo.finalMetricData?.[0];
        if (finalMetric !== undefined) {
            this.accuracy = metricAccuracy(finalMetric);
        }
        return !same;
    }

    /**
     *
     * @param val trial latest accuracy
     * @returns 0.9(FINAL) or 0.9(LATEST)
     * NaN or Infinity
     * string object such as: '{tensor: {data}}'
     *
     */
    private formatLatestAccuracyToString(val: any): string {
        if (typeof val !== 'number') {
            // show other types, such as {tensor: {data: }}
            return JSON.stringify(val);
        }
        if (isNaNorInfinity(val)) {
            return `${val}`; // show 'NaN' or 'Infinity'
        }
        return this.accuracy === undefined ? `${formatAccuracy(val)} (LATEST)` : `${formatAccuracy(val)} (FINAL)`;
    }

    /**
     *
     * @returns 0.9(FINAL) or 0.9(LATEST)
     * NaN or Infinity
     * string object such as: '{tensor: {data}}'
     * '--' when there is nothing to show (no final accuracy, and either the trial SUCCEEDED or it
     * has no intermediate results)
     *
     */
    private formatLatestAccuracy(): string {
        if (this.accuracy !== undefined) {
            return this.formatLatestAccuracyToString(this.accuracy);
        }
        if (this.status === 'SUCCEEDED' || this.intermediates.length === 0) {
            return '--';
        }
        // the last slot of a non-empty array is always an assigned element
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        const latest = this.intermediates[this.intermediates.length - 1]!;
        return this.formatLatestAccuracyToString(metricAccuracy(latest));
    }
}

export { Trial };