# tools_needleinahaystack.py
# Plot Needle-In-A-Haystack evaluation results as per-model heatmaps.
import argparse
import os
import sys
from typing import Iterable, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap


class CDMEDataset():
    """Plotting helpers for Needle-In-A-Haystack (CDME) result CSVs."""

    @staticmethod
    def visualize(path: Union[str, Iterable[str]],
                  dataset_length: str) -> None:
        """Render and save one heatmap PNG per model column of each CSV.

        Args:
            path: A CSV file path, or an iterable of CSV file paths. Each CSV
                needs a ``dataset`` column whose values follow the pattern
                ``...Length<int>Depth<number>[_...]``.
            dataset_length: Label (e.g. ``'8K'``) inserted into plot titles.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(path, str):
            path = [path]

        for file_path in path:
            df = CDMEDataset._load_results(file_path)
            for model_name in CDMEDataset._model_columns(df):
                CDMEDataset._plot_model(df, model_name, dataset_length,
                                        file_path)

    @staticmethod
    def _load_results(file_path) -> pd.DataFrame:
        """Read a results CSV and add parsed length/depth columns.

        ``file_path`` is anything ``pd.read_csv`` accepts (path or buffer).
        Adds an int ``'Context Length'`` column (text between ``Length`` and
        ``Depth``) and a float ``'Document Depth'`` column (text after
        ``Depth`` up to the first ``_``), both parsed from ``dataset``.
        """
        df = pd.read_csv(file_path)
        df['Context Length'] = df['dataset'].apply(
            lambda x: int(x.split('Length')[1].split('Depth')[0]))
        df['Document Depth'] = df['dataset'].apply(
            lambda x: float(x.split('Depth')[1].split('_')[0]))
        return df

    @staticmethod
    def _model_columns(df: pd.DataFrame) -> list:
        """Return the names of the score columns to plot, one per model."""
        columns = [
            col for col in df.columns
            if col not in ('Context Length', 'Document Depth')
        ]
        # NOTE(review): the first four remaining columns are assumed to be
        # metadata ('dataset' plus three other descriptor columns), not model
        # scores -- confirm against whatever produces the CSV.
        return columns[4:]

    @staticmethod
    def _png_path(file_path: str, model_name) -> str:
        """Build the output PNG path next to *file_path* for *model_name*."""
        # splitext only strips the final extension. str.replace('.csv', ...)
        # would also rewrite '.csv' inside directory names and, for inputs
        # without '.csv', return the CSV path itself (overwriting the input).
        stem, _ = os.path.splitext(file_path)
        # Model names such as 'org/model' would otherwise point into a
        # (likely nonexistent) subdirectory.
        safe_name = str(model_name).replace(os.sep, '_').replace('/', '_')
        return f'{stem}_{safe_name}.png'

    @staticmethod
    def _plot_model(df: pd.DataFrame, model_name, dataset_length: str,
                    file_path: str) -> str:
        """Draw, save and show the depth x length heatmap for one model.

        Returns:
            The path of the saved PNG.
        """
        model_df = df[['Document Depth', 'Context Length',
                       model_name]].copy()
        model_df.rename(columns={model_name: 'Score'}, inplace=True)
        # NOTE(review): result CSVs may contain placeholders (e.g. '-') for
        # missing scores; coerce them to NaN so the mean pivot cannot fail.
        model_df['Score'] = pd.to_numeric(model_df['Score'], errors='coerce')

        # Rows: needle depth; columns: context length; cells: mean score.
        pivot_table = pd.pivot_table(model_df,
                                     values='Score',
                                     index=['Document Depth'],
                                     columns=['Context Length'],
                                     aggfunc='mean')

        # Column means: average over all depths for each context length.
        mean_scores = pivot_table.mean().values
        overall_score = mean_scores.mean()

        fig = plt.figure(figsize=(15.5, 8))
        try:
            ax = fig.gca()
            cmap = LinearSegmentedColormap.from_list(
                'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

            # Force a tick for every row/column. With seaborn's default
            # 'auto' labels some ticks may be dropped, and the explicit
            # set_*ticklabels calls below would then mislabel cells.
            sns.heatmap(pivot_table,
                        cmap=cmap,
                        ax=ax,
                        cbar_kws={'label': 'Score'},
                        vmin=0,
                        vmax=100,
                        xticklabels=True,
                        yticklabels=True)

            # Heatmap cell centres sit at i + 0.5 on the x axis.
            x_data = [i + 0.5 for i in range(len(mean_scores))]

            # Twin axis so the 0-100 line plot is independent of the
            # heatmap's row axis.
            ax2 = ax.twinx()
            ax2.plot(x_data,
                     mean_scores,
                     color='white',
                     marker='o',
                     linestyle='-',
                     linewidth=2,
                     markersize=8,
                     label='Average Depth Score')
            ax2.set_ylim(0, 100)
            ax2.set_yticklabels([])
            ax2.set_yticks([])
            ax2.legend(loc='upper left')

            ax.set_title(f'{model_name} {dataset_length} Context '
                         'Performance\nFact Retrieval Across '
                         'Context Lengths ("Needle In A Haystack")')
            ax.set_xlabel('Token Limit')
            ax.set_ylabel('Depth Percent')
            ax.set_xticklabels(pivot_table.columns.values, rotation=45)
            ax.set_yticklabels(pivot_table.index.values, rotation=0)

            # Overall score as a subtitle below the plot area.
            ax.text(0.5,
                    -0.13, f'Overall Score for {model_name}: '
                    f'{overall_score:.2f}',
                    ha='center',
                    va='center',
                    transform=ax.transAxes,
                    fontsize=13)

            png_file_path = CDMEDataset._png_path(file_path, model_name)
            fig.tight_layout()
            fig.subplots_adjust(right=1)
            fig.canvas.draw_idle()
            fig.savefig(png_file_path)
            plt.show()
        finally:
            # Always release the figure, even if plotting fails midway.
            plt.close(fig)

        print(f'Heatmap for {model_name} saved as: {png_file_path}')
        return png_file_path


def main():
    """Parse command-line options and render heatmaps for the given CSVs.

    Exits with status 1 (message on stderr) when no ``--path`` is supplied.
    """
    parser = argparse.ArgumentParser(
        description='Generate NeedleInAHaystack Test Plots')
    parser.add_argument('--path',
                        nargs='*',
                        default=None,
                        help='Paths to CSV files for visualization')
    parser.add_argument('--dataset_length',
                        default='8K',
                        type=str,
                        help='Dataset_length for visualization')
    args = parser.parse_args()

    # With no placeholder default, omitting --path (None) and passing it
    # with no values ([]) both reach this check instead of crashing later
    # on a nonexistent file.
    if not args.path:
        sys.exit("Error: '--path' is required for visualization.")
    CDMEDataset.visualize(args.path, args.dataset_length)


if __name__ == '__main__':
    main()