"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f202a956612deeedcd44379e53e9d1b2038e8140"
plot_csv_file.py 4.33 KB
Newer Older
1
2
3
4
5
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

6
import matplotlib.pyplot as plt
7
import numpy as np
8
from matplotlib.ticker import ScalarFormatter
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

from transformers import HfArgumentParser


@dataclass
class PlotArguments:
    """
    Command-line arguments controlling how a benchmark results csv file is plotted.

    Parsed by ``HfArgumentParser`` in ``main``; each field's ``metadata["help"]`` becomes the
    ``--help`` text of the corresponding command-line option.
    """

    # Path of the csv file to read (required, no default).
    csv_file: str = field(metadata={"help": "The csv file to plot."},)
    # If True the x axis is batch size, otherwise sequence length.
    plot_along_batch: bool = field(
        default=False,
        metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."},
    )
    # Only changes the title and y-axis label (time vs. memory).
    is_time: bool = field(
        default=False,
        metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
    )
    # Keep linear axes instead of the default log-log axes.
    no_log_scale: bool = field(
        default=False, metadata={"help": "Disable logarithmic scale when plotting"},
    )
    # Only changes the title (training vs. inference).
    is_train: bool = field(
        default=False,
        metadata={
            "help": "Whether the csv file has training results or inference results. Defaults to inference results."
        },
    )
    # When None the figure is shown interactively instead of being saved.
    figure_png_file: Optional[str] = field(
        default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
    )


class Plot:
    """Load benchmark results from a csv file and plot them with matplotlib.

    The csv file is read with ``csv.DictReader`` and must provide the columns ``model``,
    ``batch_size``, ``sequence_length`` and ``result``. For every model, one dashed curve is
    drawn per fixed sequence length (or per fixed batch size when plotting along batch size).
    """

    def __init__(self, args):
        """Read ``args.csv_file`` and group its rows by model name.

        Args:
            args: a ``PlotArguments``-like object; only ``csv_file`` is read here, the other
                fields are used by ``plot``.
        """
        self.args = args
        # model name -> {"bsz": [...], "seq_len": [...], "result": {(bsz, seq_len): raw csv string}}
        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))

        with open(self.args.csv_file, newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                model_name = row["model"]
                batch_size = int(row["batch_size"])
                sequence_length = int(row["sequence_length"])
                self.result_dict[model_name]["bsz"].append(batch_size)
                self.result_dict[model_name]["seq_len"].append(sequence_length)
                self.result_dict[model_name]["result"][(batch_size, sequence_length)] = row["result"]

    @staticmethod
    def _collect_series(results, x_values, inner_value, along_batch):
        """Return the ``(x, y)`` arrays of one curve, skipping unusable points.

        Args:
            results: mapping ``(batch_size, sequence_length) -> raw result string``.
            x_values: sorted values for the x axis.
            inner_value: the value held fixed for this curve.
            along_batch: if True, ``x_values`` are batch sizes and ``inner_value`` is a sequence
                length; otherwise the roles are swapped.

        Returns:
            ``(x, y)`` as an int array and a float32 array of equal length. Combinations that
            are absent from ``results`` or whose value does not parse as a float are left out,
            so x and y stay aligned.
        """
        xs, ys = [], []
        for x in x_values:
            key = (x, inner_value) if along_batch else (inner_value, x)
            if key not in results:
                continue
            try:
                y = float(results[key])
            except (TypeError, ValueError):
                continue
            xs.append(x)
            ys.append(y)
        return np.asarray(xs, dtype=int), np.asarray(ys, dtype=np.float32)

    def plot(self):
        """Draw every model's results, then save the figure or show it interactively."""
        _, ax = plt.subplots()
        title_str = "Time usage" if self.args.is_time else "Memory usage"
        title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"

        if not self.args.no_log_scale:
            # set logarithm scales
            ax.set_xscale("log")
            ax.set_yscale("log")

        # plain tick labels instead of the log formatter's powers of ten
        for axis in [ax.xaxis, ax.yaxis]:
            axis.set_major_formatter(ScalarFormatter())

        # The labels depend only on the arguments, so they are bound even when the csv has no rows.
        (x_axis_label, inner_loop_label) = (
            ("batch_size", "sequence_length in #tokens")
            if self.args.plot_along_batch
            else ("sequence_length in #tokens", "batch_size")
        )

        for model_name, model_results in self.result_dict.items():
            batch_sizes = sorted(set(model_results["bsz"]))
            sequence_lengths = sorted(set(model_results["seq_len"]))
            results = model_results["result"]

            (x_axis_array, inner_loop_array) = (
                (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
            )

            for inner_loop_value in inner_loop_array:
                x_values, y_values = self._collect_series(
                    results, x_axis_array, inner_loop_value, self.args.plot_along_batch
                )
                if len(x_values) == 0:
                    continue
                plt.scatter(x_values, y_values, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
                plt.plot(x_values, y_values, "--")

        model_names = list(self.result_dict)
        if model_names:
            # e.g. "Memory usage for inference bert vs. gpt2"
            title_str += " " + " vs. ".join(model_names)
        y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"

        # plot
        plt.title(title_str)
        plt.xlabel(x_axis_label)
        plt.ylabel(y_axis_label)
        plt.legend()

        if self.args.figure_png_file is not None:
            plt.savefig(self.args.figure_png_file)
        else:
            plt.show()


def main():
    """Entry point: parse ``PlotArguments`` from the command line and draw the plot."""
    argument_parser = HfArgumentParser(PlotArguments)
    parsed = argument_parser.parse_args_into_dataclasses()
    Plot(args=parsed[0]).plot()


if __name__ == "__main__":
    main()