"examples/pretrain_gpt_switch.sh" did not exist on "a0bea42531bae13f9a4aa3e2c08c1aa30a66c6c5"
Compare.tsx 8.63 KB
Newer Older
Lijiao's avatar
Lijiao committed
1
import * as React from 'react';
2
import { renderToString } from 'react-dom/server';
3
import { Stack, Modal, IconButton, IDragOptions, ContextualMenu } from '@fluentui/react';
Lijiao's avatar
Lijiao committed
4
import ReactEcharts from 'echarts-for-react';
5
import { TooltipForIntermediate, TableObj, SingleAxis } from '../../static/interface';
6
import { contentStyles, iconButtonStyles } from '../buttons/ModalTheme';
Lijiao's avatar
Lijiao committed
7
import '../../static/style/compare.scss';
8
9
10
11
12
13
import { convertDuration, parseMetrics } from '../../static/function';
import { EXPERIMENT, TRIALS } from '../../static/datamodel';

/**
 * Reads the current browser viewport width (`window.innerWidth`).
 * The compare table uses it to decide how many trial columns fit.
 */
function _getWebUIWidth(): number {
    const { innerWidth } = window;
    return innerWidth;
}
Lijiaoa's avatar
Lijiaoa committed
14
15
16
17
18
19

/**
 * Drag configuration for the compare modal.
 * Uses Fluent UI's ContextualMenu as the drag menu, labelled 'Move' / 'Close'.
 */
const dragOptions: IDragOptions = {
    menu: ContextualMenu,
    moveMenuItemText: 'Move',
    closeMenuItemText: 'Close'
};
Lijiao's avatar
Lijiao committed
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// TODO: this should be refactored to the common modules
// copied from trial.ts
/**
 * Converts a trial's raw intermediate results into plain numbers for the chart.
 *
 * Iteration stops at the first `undefined` entry. Each entry's `data` is decoded
 * with `parseMetrics`. An object result contributes its `default` field; any other
 * result is used as-is.
 *
 * @param trial - trial whose `intermediates` are read
 * @returns decoded intermediate values, in the order they appear on the trial
 */
function _parseIntermediates(trial: TableObj): number[] {
    const values: number[] = [];
    for (const record of trial.intermediates) {
        if (record === undefined) {
            break;
        }
        const decoded = parseMetrics(record.data);
        // TODO: should handle more types of metric keys
        values.push(typeof decoded === 'object' ? decoded.default : decoded);
    }
    return values;
}

/**
 * Display-ready view of one trial, as consumed by the compare chart and table.
 */
interface Item {
    /** Trial ID; used as the chart series name and as the React key of table cells. */
    id: string;
    /** Trial number, shown in the "Trial No." row. */
    sequenceId: number;
    /** Pre-formatted duration string. */
    duration: string;
    /** Hyper-parameter values keyed by parameter name. */
    parameters: Map<string, any>;
    /** Final metric values keyed by metric name. */
    metrics: Map<string, any>;
    /** Intermediate results, in the order they were parsed. */
    intermediates: number[];
}

Lijiao's avatar
Lijiao committed
49
/** Props accepted by the Compare dialog. */
interface CompareProps {
    /** Trials to compare; each becomes one chart series and one table column. */
    trials: TableObj[];
    /** Text shown in the modal header. */
    title: string;
    /** When true, the parameter/metric comparison table is rendered below the chart. */
    showDetails: boolean;
    /** Called when the modal is dismissed or its close button is clicked. */
    onHideDialog: () => void;
}

/**
 * Modal dialog comparing a set of trials.
 *
 * It always shows a line chart of the trials' intermediate results. When
 * `showDetails` is set, it also shows a table of IDs, trial numbers, durations,
 * and the parameters/metrics that every compared trial has.
 */
class Compare extends React.Component<CompareProps, {}> {
    /**
     * Builds the tooltip body for one trial's series.
     * ECharts' `formatter` takes markup, hence `renderToString`.
     *
     * @param row - trial under the cursor
     * @param metricKey - final-metric key whose value is displayed
     * @returns HTML string for the tooltip
     */
    private _generateTooltipSummary(row: Item, metricKey: string): string {
        const metricValue = row.metrics.get(metricKey);
        // Only a missing value is 'N/A'; a genuine metric of 0 must still be displayed.
        const displayed = metricValue === undefined || metricValue === null ? 'N/A' : metricValue;
        return renderToString(
            <div className='tooldetailAccuracy'>
                <div>Trial ID: {row.id}</div>
                <div>Default metric: {displayed}</div>
            </div>
        );
    }

    /**
     * Renders the intermediate-result line chart.
     * There is one series per trial; the x axis is numbered 1..N, where N is the
     * longest intermediate sequence.
     *
     * @param items - trials to plot; an empty list yields an empty chart
     * @param metricKey - metric shown in the hover tooltip
     */
    private _intermediates(items: Item[], metricKey: string): React.ReactNode {
        // Seeded with 0 so an empty `items` gives an empty axis. Math.max() of nothing
        // is -Infinity, which would make the array construction throw.
        const xAxisMax = items.reduce((max, item) => Math.max(max, item.intermediates.length), 0);
        const xAxis = Array.from({ length: xAxisMax }, (_, i) => i + 1); // [1, 2, 3, ..., xAxisMax]
        const dataForEchart = items.map(item => ({
            name: item.id,
            data: item.intermediates,
            type: 'line'
        }));
        const legend = dataForEchart.map(item => item.name);

        const option = {
            tooltip: {
                trigger: 'item',
                enterable: true,
                position: (point: number[], data: TooltipForIntermediate): [number, number] => {
                    // Points in the left half of the x axis get the tooltip at the cursor.
                    // Points in the right half shift it 300 to the left (presumably to keep
                    // it inside the chart). This must compare against xAxisMax: a bare
                    // `length` would resolve to the global `window.length`.
                    if (data.dataIndex < xAxisMax / 2) {
                        return [point[0], 80];
                    }
                    return [point[0] - 300, 80];
                },
                formatter: (data: TooltipForIntermediate): string => {
                    // Series names are trial IDs, so this lookup should normally succeed.
                    const item = items.find(k => k.id === data.seriesName);
                    return item === undefined ? '' : this._generateTooltipSummary(item, metricKey);
                }
            },
            grid: {
                left: '5%',
                top: 40,
                containLabel: true
            },
            legend: {
                type: 'scroll',
                right: 40,
                left: legend.length > 6 ? 80 : null,
                data: legend
            },
            xAxis: {
                type: 'category',
                boundaryGap: false,
                data: xAxis
            },
            yAxis: {
                type: 'value',
                name: 'Metric',
                scale: true
            },
            series: dataForEchart
        };
        return (
            <ReactEcharts
                option={option}
                style={{ width: '100%', height: 418, margin: '0 auto' }}
                notMerge={true} // update now
            />
        );
    }

    /**
     * Renders one table row: a label cell followed by one cell per trial.
     *
     * @param key - React key for the row
     * @param rowName - text of the leading label cell
     * @param className - CSS class applied to every value cell
     * @param items - trials, one value cell each
     * @param formatter - extracts the displayed value from a trial
     */
    private _renderRow(
        key: string,
        rowName: string,
        className: string,
        items: Item[],
        formatter: (item: Item) => string
    ): React.ReactNode {
        return (
            <tr key={key}>
                <td className='column'>{rowName}</td>
                {items.map(item => (
                    <td className={className} key={item.id}>
                        {formatter(item)}
                    </td>
                ))}
            </tr>
        );
    }

    /**
     * Returns the keys present in every map, in the iteration order of the first map.
     *
     * @param s - maps to intersect (e.g. each trial's parameters)
     * @returns keys shared by all maps; empty when `s` is empty
     */
    private _overlapKeys(s: Map<string, any>[]): string[] {
        if (s.length === 0) {
            return [];
        }
        // Map.has keeps each membership test O(1) instead of scanning a copied key list.
        return Array.from(s[0].keys()).filter(key => s.every(m => m.has(key)));
    }

    /**
     * Renders the comparison table. It has ID, trial number and duration rows,
     * then one row per shared parameter and one per shared metric.
     *
     * @param items - trials to show as columns
     */
    private _columns(items: Item[]): React.ReactNode {
        const width = _getWebUIWidth();
        let scrollClass: string = '';
        // More columns than fit the current viewport switch the table to the 'flex' class.
        if (width > 1200) {
            scrollClass = items.length > 3 ? 'flex' : '';
        } else if (width < 700) {
            scrollClass = items.length > 1 ? 'flex' : '';
        } else {
            scrollClass = items.length > 2 ? 'flex' : '';
        }
        const parameterKeys = this._overlapKeys(items.map(item => item.parameters));
        const metricKeys = this._overlapKeys(items.map(item => item.metrics));
        return (
            <table className={`compare-modal-table ${scrollClass}`}>
                <tbody>
                    {this._renderRow('id', 'ID', 'value idList', items, item => item.id)}
                    {this._renderRow('trialnum', 'Trial No.', 'value', items, item => item.sequenceId.toString())}
                    {this._renderRow('duration', 'Duration', 'value', items, item => item.duration)}
                    {parameterKeys.map(k =>
                        this._renderRow(`space_${k}`, k, 'value', items, item => item.parameters.get(k))
                    )}
                    {metricKeys.map(k =>
                        this._renderRow(`metrics_${k}`, `Metric: ${k}`, 'value', items, item => item.metrics.get(k))
                    )}
                </tbody>
            </table>
        );
    }

    render(): React.ReactNode {
        const { onHideDialog, trials, title, showDetails } = this.props;
        // Re-key a SingleAxis-keyed map by each axis' baseName so trials compare by name.
        const flatten = (m: Map<SingleAxis, any>): Map<string, any> => {
            return new Map(Array.from(m).map(([key, value]) => [key.baseName, value]));
        };
        const inferredSearchSpace = TRIALS.inferredSearchSpace(EXPERIMENT.searchSpaceNew);
        const items: Item[] = trials.map(trial => ({
            id: trial.id,
            sequenceId: trial.sequenceId,
            duration: convertDuration(trial.duration),
            parameters: flatten(trial.parameters(inferredSearchSpace)),
            metrics: flatten(trial.metrics(TRIALS.inferredMetricSpace())),
            intermediates: _parseIntermediates(trial)
        }));
        const metricKeys = this._overlapKeys(items.map(item => item.metrics));
        // Prefer 'default'; otherwise use the first shared metric. With no shared
        // metrics, fall back to 'default' rather than an undefined key.
        const defaultMetricKey =
            metricKeys.length === 0 || metricKeys.includes('default') ? 'default' : metricKeys[0];

        return (
            <Modal
                isOpen={true}
                containerClassName={contentStyles.container}
                className='compare-modal'
                allowTouchBodyScroll={true}
                dragOptions={dragOptions}
                onDismiss={onHideDialog}
            >
                <div>
                    <div className={contentStyles.header}>
                        <span>{title}</span>
                        <IconButton
                            styles={iconButtonStyles}
                            iconProps={{ iconName: 'Cancel' }}
                            ariaLabel='Close popup modal'
                            onClick={onHideDialog}
                        />
                    </div>
                    <Stack className='compare-modal-intermediate'>
                        {this._intermediates(items, defaultMetricKey)}
                        <Stack className='compare-yAxis'># Intermediate result</Stack>
                    </Stack>
                    {showDetails && <Stack>{this._columns(items)}</Stack>}
                </div>
            </Modal>
        );
    }
}

export default Compare;