// searchspace.ts
import { SingleAxis, MultipleAxes } from '../interface';
import { Trial } from './trial';
import { SUPPORTED_SEARCH_SPACE_TYPE } from '../const';
import { formatComplexTypeValue } from '../function';

/**
 * Join a parent path and a child name into a slash-separated full name.
 *
 * @param prefix - full name of the parent; an empty string means "no parent"
 * @param name - base name of the child
 * @returns `prefix/name`, or just `name` when `prefix` is empty
 */
function fullNameJoin(prefix: string, name: string): string {
    return prefix ? prefix + '/' + name : name;
}

/**
 * A continuous axis whose bounds are derived from a numeric search-space spec.
 *
 * - `randint`: `[value[0], value[1] - 1]` (the upper bound is treated as exclusive)
 * - any type containing `uniform`: `[value[0], value[1]]`
 * - any type containing `normal`: `mu ± 4 * sigma` with `[mu, sigma] = value`,
 *   exponentiated when the type contains `log`
 *
 * Any other type leaves the bounds at their initial `[0, 0]`.
 */
class NumericAxis implements SingleAxis {
    min: number = 0;
    max: number = 0;
    type: string;
    baseName: string;
    fullName: string;
    scale: 'log' | 'linear';
    nested = false;

    constructor(baseName: string, fullName: string, type: string, value: any) {
        this.baseName = baseName;
        this.fullName = fullName;
        this.type = type;
        const logScaled = type.includes('log');
        this.scale = logScaled ? 'log' : 'linear';

        if (type === 'randint') {
            // Exclusive upper bound: the largest reachable integer is one below it.
            this.min = value[0];
            this.max = value[1] - 1;
            return;
        }
        if (type.includes('uniform')) {
            this.min = value[0];
            this.max = value[1];
            return;
        }
        if (!type.includes('normal')) {
            return;
        }
        // Cover mu ± 4 sigma; for log-scaled types the interval is mapped back with exp().
        const mu = value[0];
        const sigma = value[1];
        const lower = mu - 4 * sigma;
        const upper = mu + 4 * sigma;
        this.min = logScaled ? Math.exp(lower) : lower;
        this.max = logScaled ? Math.exp(upper) : upper;
    }

    /** The `[min, max]` interval spanned by this axis. */
    get domain(): [number, number] {
        const { min, max } = this;
        return [min, max];
    }
}

/**
 * An ordinal (categorical) axis whose domain is the list of candidate values,
 * each passed through `formatComplexTypeValue`.
 */
class SimpleOrdinalAxis implements SingleAxis {
    type: string;
    baseName: string;
    fullName: string;
    scale: 'ordinal' = 'ordinal';
    domain: any[];
    nested = false;

    /**
     * @param baseName - key of this axis within its search space
     * @param fullName - slash-joined path of this axis
     * @param type - spec type of the axis (e.g. `choice`)
     * @param value - array or iterable of candidate values; callers in this file
     *   pass either an array or a `Set` iterator
     */
    constructor(baseName: string, fullName: string, type: string, value: any) {
        this.baseName = baseName;
        this.fullName = fullName;
        this.type = type;
        // Forward only the element: `map` would otherwise also pass the index and the
        // array as extra arguments to the formatter.
        this.domain = Array.from(value).map(v => formatComplexTypeValue(v));
    }
}

/**
 * An ordinal axis whose choices are themselves sub-search-spaces.
 *
 * Each element of `value` is read for its `_name` key and turned into a nested
 * `SearchSpace` stored in `domain` under that name, with full name `fullName/_name`.
 */
class NestedOrdinalAxis implements SingleAxis {
    type: string;
    baseName: string;
    fullName: string;
    scale: 'ordinal' = 'ordinal';
    domain = new Map<string, MultipleAxes>();
    nested = true;

    /**
     * @param baseName - key of this axis within its search space
     * @param fullName - slash-joined path of this axis
     * @param type - spec type of the axis (e.g. `choice`)
     * @param value - iterable of choice objects, each expected to carry `_name`
     */
    constructor(baseName: string, fullName: string, type: string, value: any) {
        this.baseName = baseName;
        this.fullName = fullName;
        this.type = type;
        for (const choice of value) {
            const childName = choice._name;
            // eslint-disable-next-line @typescript-eslint/no-use-before-define
            this.domain.set(childName, new SearchSpace(childName, fullNameJoin(fullName, childName), choice));
        }
    }
}

/**
 * A (possibly nested) search space holding one axis per hyper-parameter key.
 */
export class SearchSpace implements MultipleAxes {
    axes = new Map<string, SingleAxis>();
    baseName: string;
    fullName: string;

    /**
     * @param baseName - name of this space (the choice name for nested spaces)
     * @param fullName - slash-joined path of this space; empty string for the root
     * @param searchSpaceSpec - mapping from key to `{ _type, _value }`; a `_name` key is
     *   skipped, and `undefined` yields a space with no axes
     */
    constructor(baseName: string, fullName: string, searchSpaceSpec: any) {
        this.baseName = baseName;
        this.fullName = fullName;
        if (searchSpaceSpec === undefined) {
            return;
        }
        for (const [key, spec] of Object.entries<any>(searchSpaceSpec)) {
            if (key === '_name') {
                // The name of a nested choice sits alongside its parameters; it is not an axis.
                continue;
            }
            // Every axis uses the same '/' separator so nested names stay unambiguous
            // (plain concatenation would yield e.g. 'optimizer/adamlr').
            const axisFullName = fullNameJoin(fullName, key);
            if (['choice', 'layer_choice', 'input_choice'].includes(spec._type)) {
                const first = spec._value ? spec._value[0] : undefined;
                if (typeof first === 'object' && first !== null) {
                    // Choices are objects: each one is a nested sub-space.
                    this.axes.set(key, new NestedOrdinalAxis(key, axisFullName, spec._type, spec._value));
                } else {
                    this.axes.set(key, new SimpleOrdinalAxis(key, axisFullName, spec._type, spec._value));
                }
            } else if (SUPPORTED_SEARCH_SPACE_TYPE.includes(spec._type)) {
                this.axes.set(key, new NumericAxis(key, axisFullName, spec._type, spec._value));
            }
        }
    }

    /**
     * Return a copy of `searchSpace` extended with axes for parameters that trials
     * report but the search space does not declare.
     *
     * `trial.parameters(searchSpace)` is presumably expected to throw a `Map` of the
     * unexpected entries (name -> value); those values are pooled per name. A pooled
     * column becomes a uniform `NumericAxis` spanning the observed values when all of
     * them are numbers, otherwise a `SimpleOrdinalAxis` over the distinct values.
     *
     * @throws whatever `trial.parameters` throws when it is not a `Map`
     */
    static inferFromTrials(searchSpace: SearchSpace, trials: Trial[]): SearchSpace {
        const newSearchSpace = new SearchSpace(searchSpace.baseName, searchSpace.fullName, undefined);
        for (const [k, v] of searchSpace.axes) {
            newSearchSpace.axes.set(k, v);
        }
        // Add axes inferred from trial columns
        const addingColumns = new Map<string, any[]>();
        for (const trial of trials) {
            try {
                trial.parameters(searchSpace);
            } catch (unexpectedEntries) {
                // eslint-disable-next-line no-console
                console.warn(unexpectedEntries);
                if (!(unexpectedEntries instanceof Map)) {
                    // Not the Map of unexpected entries: a genuine error. Re-throw it rather
                    // than failing below with a misleading "not iterable" TypeError.
                    throw unexpectedEntries;
                }
                for (const [k, v] of unexpectedEntries as Map<string, any>) {
                    const column = addingColumns.get(k);
                    if (column === undefined) {
                        addingColumns.set(k, [v]);
                    } else {
                        column.push(v);
                    }
                }
            }
        }
        addingColumns.forEach((value, key) => {
            if (value.every(v => typeof v === 'number')) {
                newSearchSpace.axes.set(
                    key,
                    new NumericAxis(key, key, 'uniform', [Math.min(...value), Math.max(...value)])
                );
            } else {
                newSearchSpace.axes.set(key, new SimpleOrdinalAxis(key, key, 'choice', new Set(value).values()));
            }
        });
        return newSearchSpace;
    }
}

/**
 * Axes for the metrics reported by trials.
 *
 * A numeric `trial.acc` is treated as `{ default: acc }`; otherwise each entry of
 * `trial.acc` becomes a column. Columns whose values are all numbers become uniform
 * `NumericAxis`es spanning the observed range; any other column becomes a
 * `SimpleOrdinalAxis` over its distinct values.
 */
export class MetricSpace implements MultipleAxes {
    axes = new Map<string, SingleAxis>();
    baseName = '';
    fullName = '';

    constructor(trials: Trial[]) {
        const columns = new Map<string, any[]>();
        for (const trial of trials) {
            // `== null` also skips `null`, which `Object.entries` below would throw on.
            if (trial.acc == null) {
                continue;
            }
            // TODO: handle more than number and object
            const acc = typeof trial.acc === 'number' ? { default: trial.acc } : trial.acc;
            for (const [k, v] of Object.entries(acc)) {
                const column = columns.get(k);
                if (column === undefined) {
                    columns.set(k, [v]);
                } else {
                    column.push(v);
                }
            }
        }
        columns.forEach((value, key) => {
            if (value.every(v => typeof v === 'number')) {
                this.axes.set(key, new NumericAxis(key, key, 'uniform', [Math.min(...value), Math.max(...value)]));
            } else {
                // An ordinal domain should list each category once; this matches the
                // Set-based deduplication SearchSpace.inferFromTrials applies.
                this.axes.set(key, new SimpleOrdinalAxis(key, key, 'choice', new Set(value).values()));
            }
        });
    }
}