Unverified Commit 79a82cc0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Benchmarks] improve Example Plotter (#5245)

* improve plotting

* better labels

* fix time plot
parent 88d7f96e
import csv import csv
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import List, Optional
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
...@@ -10,6 +10,10 @@ from matplotlib.ticker import ScalarFormatter ...@@ -10,6 +10,10 @@ from matplotlib.ticker import ScalarFormatter
from transformers import HfArgumentParser from transformers import HfArgumentParser
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass @dataclass
class PlotArguments: class PlotArguments:
""" """
...@@ -37,6 +41,25 @@ class PlotArguments: ...@@ -37,6 +41,25 @@ class PlotArguments:
figure_png_file: Optional[str] = field( figure_png_file: Optional[str] = field(
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."}, default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
) )
short_model_names: Optional[List[str]] = list_field(
default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
)
def can_convert_to_int(string):
try:
int(string)
return True
except ValueError:
return False
def can_convert_to_float(string):
try:
float(string)
return True
except ValueError:
return False
class Plot: class Plot:
...@@ -50,9 +73,16 @@ class Plot: ...@@ -50,9 +73,16 @@ class Plot:
model_name = row["model"] model_name = row["model"]
self.result_dict[model_name]["bsz"].append(int(row["batch_size"])) self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"])) self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[ if can_convert_to_int(row["result"]):
"result" # value is not None
] self.result_dict[model_name]["result"][
(int(row["batch_size"]), int(row["sequence_length"]))
] = int(row["result"])
elif can_convert_to_float(row["result"]):
# value is not None
self.result_dict[model_name]["result"][
(int(row["batch_size"]), int(row["sequence_length"]))
] = float(row["result"])
def plot(self): def plot(self):
fig, ax = plt.subplots() fig, ax = plt.subplots()
...@@ -67,7 +97,7 @@ class Plot: ...@@ -67,7 +97,7 @@ class Plot:
for axis in [ax.xaxis, ax.yaxis]: for axis in [ax.xaxis, ax.yaxis]:
axis.set_major_formatter(ScalarFormatter()) axis.set_major_formatter(ScalarFormatter())
for model_name in self.result_dict.keys(): for model_name_idx, model_name in enumerate(self.result_dict.keys()):
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"]))) batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"]))) sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
results = self.result_dict[model_name]["result"] results = self.result_dict[model_name]["result"]
...@@ -76,23 +106,33 @@ class Plot: ...@@ -76,23 +106,33 @@ class Plot:
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes) (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
) )
label_model_name = (
model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx]
)
for inner_loop_value in inner_loop_array: for inner_loop_value in inner_loop_array:
if self.args.plot_along_batch: if self.args.plot_along_batch:
y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int) y_axis_array = np.asarray(
[results[(x, inner_loop_value)] for x in x_axis_array if (x, inner_loop_value) in results],
dtype=np.int,
)
else: else:
y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32) y_axis_array = np.asarray(
[results[(inner_loop_value, x)] for x in x_axis_array if (inner_loop_value, x) in results],
dtype=np.float32,
)
(x_axis_label, inner_loop_label) = ( (x_axis_label, inner_loop_label) = (
("batch_size", "sequence_length in #tokens") ("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz")
if self.args.plot_along_batch
else ("sequence_length in #tokens", "batch_size")
) )
x_axis_array = np.asarray(x_axis_array, np.int) x_axis_array = np.asarray(x_axis_array, np.int)[: len(y_axis_array)]
plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}") plt.scatter(
x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}"
)
plt.plot(x_axis_array, y_axis_array, "--") plt.plot(x_axis_array, y_axis_array, "--")
title_str += f" {model_name} vs." title_str += f" {label_model_name} vs."
title_str = title_str[:-4] title_str = title_str[:-4]
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB" y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
......
model,batch_size,sequence_length,result
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,8,512,0.2032
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,64,512,1.5279
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,256,512,6.1837
...@@ -74,12 +74,6 @@ class BenchmarkArguments: ...@@ -74,12 +74,6 @@ class BenchmarkArguments:
"help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU." "help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
}, },
) )
with_lm_head: bool = field(
default=False,
metadata={
"help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"
},
)
inference_time_csv_file: str = field( inference_time_csv_file: str = field(
default=f"inference_time_{round(time())}.csv", default=f"inference_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv."}, metadata={"help": "CSV filename used if saving time results to csv."},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment