Unverified Commit 96f57c9c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Benchmark] Memory benchmark utils (#4198)



* improve memory benchmarking

* correct typo

* fix current memory

* check torch memory allocated

* better pytorch function

* add total cached gpu memory

* add total gpu required

* improve torch gpu usage

* update memory usage

* finalize memory tracing

* save intermediate benchmark class

* fix conflict

* improve benchmark

* improve benchmark

* finalize

* make style

* improve benchmarking

* correct typo

* make train function more flexible

* fix csv save

* better repr of bytes

* better print

* fix __repr__ bug

* finish plot script

* rename plot file

* delete csv and small improvements

* fix in plot

* fix in plot

* correct usage of timeit

* remove redundant line

* remove redundant line

* fix bug

* add hf parser tests

* add versioning and platform info

* make style

* add gpu information

* ensure backward compatibility

* finish adding all tests

* Update src/transformers/benchmark/benchmark_args.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Update src/transformers/benchmark/benchmark_args_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* delete csv files

* fix isort ordering

* add out of memory handling

* add better train memory handling
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent ec4cdfdd
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
from transformers import HfArgumentParser
@dataclass
class PlotArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
csv_file: str = field(metadata={"help": "The csv file to plot."},)
plot_along_batch: bool = field(
default=False,
metadata={"help": "Whether to plot along batch size or sequence lengh. Defaults to sequence length."},
)
is_time: bool = field(
default=False,
metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
)
is_train: bool = field(
default=False,
metadata={
"help": "Whether the csv file has training results or inference results. Defaults to inference results."
},
)
figure_png_file: Optional[str] = field(
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
)
class Plot:
def __init__(self, args):
self.args = args
self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
with open(self.args.csv_file, newline="") as csv_file:
reader = csv.DictReader(csv_file)
for row in reader:
model_name = row["model"]
self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
"result"
]
def plot(self):
fig, ax = plt.subplots()
title_str = "Time usage" if self.args.is_time else "Memory usage"
title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"
for model_name in self.result_dict.keys():
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
results = self.result_dict[model_name]["result"]
(x_axis_array, inner_loop_array) = (
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
)
plt.xlim(min(x_axis_array), max(x_axis_array))
for inner_loop_value in inner_loop_array:
if self.args.plot_along_batch:
y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
else:
y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
ax.set_xscale("log", basex=2)
ax.set_yscale("log", basey=10)
(x_axis_label, inner_loop_label) = (
("batch_size", "sequence_length in #tokens")
if self.args.plot_along_batch
else ("sequence_length in #tokens", "batch_size")
)
x_axis_array = np.asarray(x_axis_array, np.int)
plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
plt.plot(x_axis_array, y_axis_array, "--")
title_str += f" {model_name} vs."
title_str = title_str[:-4]
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
# plot
plt.title(title_str)
plt.xlabel(x_axis_label)
plt.ylabel(y_axis_label)
plt.legend()
if self.args.figure_png_file is not None:
plt.savefig(self.args.figure_png_file)
else:
plt.show()
def main():
parser = HfArgumentParser(PlotArguments)
plot_args = parser.parse_args_into_dataclasses()[0]
plot = Plot(args=plot_args)
plot.plot()
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """
from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments
def main():
parser = HfArgumentParser(PyTorchBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
benchmark = PyTorchBenchmark(args=benchmark_args)
benchmark.run()
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """
# If checking the tensors placement
# tf.debugging.set_log_device_placement(True)
import argparse
import csv
import logging
import timeit
from time import time
from typing import Callable, List
from transformers import (
AutoConfig,
AutoTokenizer,
MemorySummary,
is_tf_available,
is_torch_available,
start_memory_tracing,
stop_memory_tracing,
)
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModel
if is_torch_available():
import torch
from transformers import AutoModel
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
the Director of Hatcheries and Conditioning entered the room, in the
scarcely breathing silence, the absent-minded, soliloquizing hum or
whistle, of absorbed concentration. A troop of newly arrived students,
very young, pink and callow, followed nervously, rather abjectly, at the
Director's heels. Each of them carried a notebook, in which, whenever
the great man spoke, he desperately scribbled. Straight from the
horse's mouth. It was a rare privilege. The D. H. C. for Central London
always made a point of personally conducting his new students round
the various departments.
"Just to give you a general idea," he would explain to them. For of
course some sort of general idea they must have, if they were to do
their work intelligently-though as little of one, if they were to be good
and happy members of society, as possible. For particulars, as every
one knows, make for virtue and happiness; generalities are intellectu-
ally necessary evils. Not philosophers but fret-sawyers and stamp col-
lectors compose the backbone of society.
"To-morrow," he would add, smiling at them with a slightly menacing
geniality, "you'll be settling down to serious work. You won't have time
for generalities. Meanwhile ..."
Meanwhile, it was a privilege. Straight from the horse's mouth into the
notebook. The boys scribbled like mad.
Tall and rather thin but upright, the Director advanced into the room.
He had a long chin and big rather prominent teeth, just covered, when
he was not talking, by his full, floridly curved lips. Old, young? Thirty?
Fifty? Fifty-five? It was hard to say. And anyhow the question didn't
arise; in this year of stability, A. F. 632, it didn't occur to you to ask it.
"I shall begin at the beginning," said the D.H.C. and the more zealous
students recorded his intention in their notebooks: Begin at the begin-
ning. "These," he waved his hand, "are the incubators." And opening
an insulated door he showed them racks upon racks of numbered test-
tubes. "The week's supply of ova. Kept," he explained, "at blood heat;
whereas the male gametes," and here he opened another door, "they
have to be kept at thirty-five instead of thirty-seven. Full blood heat
sterilizes." Rams wrapped in theremogene beget no lambs.
Still leaning against the incubators he gave them, while the pencils
scurried illegibly across the pages, a brief description of the modern
fertilizing process; spoke first, of course, of its surgical introduc-
tion-"the operation undergone voluntarily for the good of Society, not
to mention the fact that it carries a bonus amounting to six months'
salary"; continued with some account of the technique for preserving
the excised ovary alive and actively developing; passed on to a consid-
eration of optimum temperature, salinity, viscosity; referred to the liq-
uor in which the detached and ripened eggs were kept; and, leading
his charges to the work tables, actually showed them how this liquor
was drawn off from the test-tubes; how it was let out drop by drop
onto the specially warmed slides of the microscopes; how the eggs
which it contained were inspected for abnormalities, counted and
transferred to a porous receptacle; how (and he now took them to
watch the operation) this receptacle was immersed in a warm bouillon
containing free-swimming spermatozoa-at a minimum concentration
of one hundred thousand per cubic centimetre, he insisted; and how,
after ten minutes, the container was lifted out of the liquor and its
contents re-examined; how, if any of the eggs remained unfertilized, it
was again immersed, and, if necessary, yet again; how the fertilized
ova went back to the incubators; where the Alphas and Betas re-
mained until definitely bottled; while the Gammas, Deltas and Epsilons
were brought out again, after only thirty-six hours, to undergo Bo-
kanovsky's Process.
"Bokanovsky's Process," repeated the Director, and the students un-
derlined the words in their little notebooks.
One egg, one embryo, one adult-normality. But a bokanovskified egg
will bud, will proliferate, will divide. From eight to ninety-six buds, and
every bud will grow into a perfectly formed embryo, and every embryo
into a full-sized adult. Making ninety-six human beings grow where
only one grew before. Progress.
"Essentially," the D.H.C. concluded, "bokanovskification consists of a
series of arrests of development. We check the normal growth and,
paradoxically enough, the egg responds by budding."
Responds by budding. The pencils were busy.
He pointed. On a very slowly moving band a rack-full of test-tubes was
entering a large metal box, another, rack-full was emerging. Machinery
faintly purred. It took eight minutes for the tubes to go through, he
told them. Eight minutes of hard X-rays being about as much as an
egg can stand. A few died; of the rest, the least susceptible divided
into two; most put out four buds; some eight; all were returned to the
incubators, where the buds began to develop; then, after two days,
were suddenly chilled, chilled and checked. Two, four, eight, the buds
in their turn budded; and having budded were dosed almost to death
with alcohol; consequently burgeoned again and having budded-bud
out of bud out of bud-were thereafter-further arrest being generally
fatal-left to develop in peace. By which time the original egg was in a
fair way to becoming anything from eight to ninety-six embryos- a
prodigious improvement, you will agree, on nature. Identical twins-but
not in piddling twos and threes as in the old viviparous days, when an
egg would sometimes accidentally divide; actually by dozens, by
scores at a time.
"Scores," the Director repeated and flung out his arms, as though he
were distributing largesse. "Scores."
But one of the students was fool enough to ask where the advantage
lay.
"My good boy!" The Director wheeled sharply round on him. "Can't you
see? Can't you see?" He raised a hand; his expression was solemn.
"Bokanovsky's Process is one of the major instruments of social stabil-
ity!"
Major instruments of social stability.
Standard men and women; in uniform batches. The whole of a small
factory staffed with the products of a single bokanovskified egg.
"Ninety-six identical twins working ninety-six identical machines!" The
voice was almost tremulous with enthusiasm. "You really know where
you are. For the first time in history." He quoted the planetary motto.
"Community, Identity, Stability." Grand words. "If we could bo-
kanovskify indefinitely the whole problem would be solved."
Solved by standard Gammas, unvarying Deltas, uniform Epsilons. Mil-
lions of identical twins. The principle of mass production at last applied
to biology.
"But, alas," the Director shook his head, "we can't bokanovskify indefi-
nitely."
Ninety-six seemed to be the limit; seventy-two a good average. From
the same ovary and with gametes of the same male to manufacture as
many batches of identical twins as possible-that was the best (sadly a
second best) that they could do. And even that was difficult.
"For in nature it takes thirty years for two hundred eggs to reach ma-
turity. But our business is to stabilize the population at this moment,
here and now. Dribbling out twins over a quarter of a century-what
would be the use of that?"
Obviously, no use at all. But Podsnap's Technique had immensely ac-
celerated the process of ripening. They could make sure of at least a
hundred and fifty mature eggs within two years. Fertilize and bo-
kanovskify-in other words, multiply by seventy-two-and you get an
average of nearly eleven thousand brothers and sisters in a hundred
and fifty batches of identical twins, all within two years of the same
age.
"And in exceptional cases we can make one ovary yield us over fifteen
thousand adult individuals."
Beckoning to a fair-haired, ruddy young man who happened to be
passing at the moment. "Mr. Foster," he called. The ruddy young man
approached. "Can you tell us the record for a single ovary, Mr. Foster?"
"Sixteen thousand and twelve in this Centre," Mr. Foster replied with-
out hesitation. He spoke very quickly, had a vivacious blue eye, and
took an evident pleasure in quoting figures. "Sixteen thousand and
twelve; in one hundred and eighty-nine batches of identicals. But of
course they've done much better," he rattled on, "in some of the tropi-
cal Centres. Singapore has often produced over sixteen thousand five
hundred; and Mombasa has actually touched the seventeen thousand
mark. But then they have unfair advantages. You should see the way a
negro ovary responds to pituitary! It's quite astonishing, when you're
used to working with European material. Still," he added, with a laugh
(but the light of combat was in his eyes and the lift of his chin was
challenging), "still, we mean to beat them if we can. I'm working on a
wonderful Delta-Minus ovary at this moment. Only just eighteen
months old. Over twelve thousand seven hundred children already, ei-
ther decanted or in embryo. And still going strong. We'll beat them
yet."
"That's the spirit I like!" cried the Director, and clapped Mr. Foster on
the shoulder. "Come along with us, and give these boys the benefit of
your expert knowledge."
Mr. Foster smiled modestly. "With pleasure." They went.
In the Bottling Room all was harmonious bustle and ordered activity.
Flaps of fresh sow's peritoneum ready cut to the proper size came
shooting up in little lifts from the Organ Store in the sub-basement.
Whizz and then, click! the lift-hatches hew open; the bottle-liner had
only to reach out a hand, take the flap, insert, smooth-down, and be-
fore the lined bottle had had time to travel out of reach along the end-
less band, whizz, click! another flap of peritoneum had shot up from
the depths, ready to be slipped into yet another bottle, the next of that
slow interminable procession on the band.
Next to the Liners stood the Matriculators. The procession advanced;
one by one the eggs were transferred from their test-tubes to the
larger containers; deftly the peritoneal lining was slit, the morula
dropped into place, the saline solution poured in ... and already the
bottle had passed, and it was the turn of the labellers. Heredity, date
of fertilization, membership of Bokanovsky Group-details were trans-
ferred from test-tube to bottle. No longer anonymous, but named,
identified, the procession marched slowly on; on through an opening in
the wall, slowly on into the Social Predestination Room.
"Eighty-eight cubic metres of card-index," said Mr. Foster with relish,
as they entered."""
def create_setup_and_compute(
model_names: List[str],
batch_sizes: List[int],
slice_sizes: List[int],
gpu: bool = True,
tensorflow: bool = False,
average_over: int = 3,
no_speed: bool = False,
no_memory: bool = False,
verbose: bool = False,
torchscript: bool = False,
xla: bool = False,
amp: bool = False,
fp16: bool = False,
save_to_csv: bool = False,
csv_time_filename: str = f"time_{round(time())}.csv",
csv_memory_filename: str = f"memory_{round(time())}.csv",
print_fn: Callable[[str], None] = print,
):
if xla:
tf.config.optimizer.set_jit(True)
if amp:
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
if tensorflow:
dictionary = {model_name: {} for model_name in model_names}
results = _compute_tensorflow(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
amp,
no_speed,
no_memory,
verbose,
print_fn,
)
else:
device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
dictionary = {model_name: {} for model_name in model_names}
results = _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
print_fn,
)
print_fn("=========== RESULTS ===========")
for model_name in model_names:
print_fn("\t" + f"======= MODEL CHECKPOINT: {model_name} =======")
for batch_size in results[model_name]["bs"]:
print_fn("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
for slice_size in results[model_name]["ss"]:
time = results[model_name]["time"][batch_size][slice_size]
memory = results[model_name]["memory"][batch_size][slice_size]
if isinstance(time, str):
print_fn(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{time} " f"{memory}")
else:
print_fn(
f"\t\t{model_name}/{batch_size}/{slice_size}: "
f"{(round(1000 * time) / 1000)}"
f"s "
f"{memory}"
)
if save_to_csv:
with open(csv_time_filename, mode="w") as csv_time_file, open(
csv_memory_filename, mode="w"
) as csv_memory_file:
assert len(model_names) > 0, "At least 1 model should be defined, but got {}".format(model_names)
fieldnames = ["model", "batch_size", "sequence_length"]
time_writer = csv.DictWriter(csv_time_file, fieldnames=fieldnames + ["time_in_s"])
time_writer.writeheader()
memory_writer = csv.DictWriter(csv_memory_file, fieldnames=fieldnames + ["memory"])
memory_writer.writeheader()
for model_name in model_names:
time_dict = results[model_name]["time"]
memory_dict = results[model_name]["memory"]
for bs in time_dict:
for ss in time_dict[bs]:
time_writer.writerow(
{
"model": model_name,
"batch_size": bs,
"sequence_length": ss,
"time_in_s": "{:.4f}".format(time_dict[bs][ss]),
}
)
for bs in memory_dict:
for ss in time_dict[bs]:
memory_writer.writerow(
{
"model": model_name,
"batch_size": bs,
"sequence_length": ss,
"memory": memory_dict[bs][ss],
}
)
def print_summary_statistics(summary: MemorySummary, print_fn: Callable[[str], None]):
print_fn(
"\nLines by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
print_fn(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
print_fn(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
print_fn(f"\nTotal memory increase: {summary.total}")
def get_print_function(save_print_log, log_filename):
if save_print_log:
logging.basicConfig(
level=logging.DEBUG,
filename=log_filename,
filemode="a+",
format="%(asctime)-15s %(levelname)-8s %(message)s",
)
def print_with_print_log(*args):
logging.info(*args)
print(*args)
return print_with_print_log
else:
return print
def _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
print_fn,
):
for c, model_name in enumerate(model_names):
print_fn(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript)
model = AutoModel.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
max_input_size = tokenizer.max_model_input_sizes[model_name]
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "time": {}, "memory": {}}
dictionary[model_name]["time"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}
print_fn("Using model {}".format(model))
print_fn("Number of all parameters {}".format(model.num_parameters()))
for batch_size in batch_sizes:
if fp16:
model.half()
model.to(device)
model.eval()
for slice_size in slice_sizes:
if max_input_size is not None and slice_size > max_input_size:
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
else:
sequence = torch.tensor(tokenized_sequence[:slice_size], device=device).repeat(batch_size, 1)
try:
if torchscript:
print_fn("Tracing model with sequence size {}".format(sequence.shape))
inference = torch.jit.trace(model, sequence)
inference(sequence)
else:
inference = model
inference(sequence)
if not no_memory:
# model.add_memory_hooks() # Forward method tracing (only for PyTorch models)
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)
if verbose:
print_summary_statistics(summary, print_fn)
dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
if not no_speed:
print_fn("Going through model with sequence of shape {}".format(sequence.shape))
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["time"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
except RuntimeError as e:
print_fn("Doesn't fit on GPU. {}".format(e))
torch.cuda.empty_cache()
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary
def _compute_tensorflow(
model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose, print_fn
):
for c, model_name in enumerate(model_names):
print_fn(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name)
model = TFAutoModel.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
max_input_size = tokenizer.max_model_input_sizes[model_name]
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "time": {}, "memory": {}}
dictionary[model_name]["time"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}
print_fn("Using model {}".format(model))
print_fn("Number of all parameters {}".format(model.num_parameters()))
@tf.function
def inference(inputs):
return model(inputs)
for batch_size in batch_sizes:
for slice_size in slice_sizes:
if max_input_size is not None and slice_size > max_input_size:
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
else:
sequence = tf.stack(
[tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
)
try:
print_fn("Going through model with sequence of shape {}".format(sequence.shape))
# To make sure that the model is traced + that the tensors are on the appropriate device
inference(sequence)
if not no_memory:
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)
if verbose:
print_summary_statistics(summary, print_fn)
dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
if not no_speed:
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["time"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
except tf.errors.ResourceExhaustedError as e:
print_fn("Doesn't fit on GPU. {}".format(e))
dictionary[model_name]["time"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
required=False,
type=str,
default="all",
help="Model checkpoints to be provided "
"to the AutoModel classes. Leave "
"blank to benchmark the base version "
"of all available model "
"architectures.",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Verbose memory tracing")
parser.add_argument("--no_speed", required=False, action="store_true", help="Don't perform speed measurments")
parser.add_argument("--no_memory", required=False, action="store_true", help="Don't perform memory measurments")
parser.add_argument(
"--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
)
parser.add_argument(
"--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
)
parser.add_argument(
"--torchscript",
required=False,
action="store_true",
help="Pytorch only: trace the models " "using torchscript",
)
parser.add_argument(
"--tensorflow",
required=False,
action="store_true",
help="Benchmark the TensorFlow version "
"of the models. Will run on GPU if "
"the correct dependencies are "
"installed",
)
parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
parser.add_argument(
"--amp",
required=False,
action="store_true",
help="TensorFlow only: use automatic mixed precision acceleration.",
)
parser.add_argument(
"--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
)
parser.add_argument(
"--keras_predict",
required=False,
action="store_true",
help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
)
parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
parser.add_argument(
"--log_print", required=False, action="store_true", help="Save all print statements in log file."
)
parser.add_argument(
"--csv_time_filename",
required=False,
default=f"time_{round(time())}.csv",
help="CSV filename used if saving time results to csv.",
)
parser.add_argument(
"--csv_memory_filename",
required=False,
default=f"memory_{round(time())}.csv",
help="CSV filename used if saving memory results to csv.",
)
parser.add_argument(
"--log_filename",
required=False,
default=f"log_{round(time())}.txt",
help="Log filename used if print statements are saved in log.",
)
parser.add_argument(
"--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
)
parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1, 2, 4, 8])
parser.add_argument("--slice_sizes", nargs="+", type=int, default=[8, 64, 128, 256, 512, 1024])
args = parser.parse_args()
if args.models == "all":
args.models = [
"gpt2",
"bert-base-cased",
"xlnet-base-cased",
"xlm-mlm-en-2048",
"transfo-xl-wt103",
"openai-gpt",
"distilbert-base-uncased",
"distilgpt2",
"roberta-base",
"ctrl",
"t5-base",
"bart-large",
]
else:
args.models = args.models.split()
print_fn = get_print_function(args.log_print, args.log_filename)
print_fn("Running with arguments: {}".format(args))
if args.torch:
if is_torch_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=False,
gpu=args.torch_cuda,
torchscript=args.torchscript,
fp16=args.fp16,
save_to_csv=args.save_to_csv,
csv_time_filename=args.csv_time_filename,
csv_memory_filename=args.csv_memory_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
print_fn=print_fn,
)
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
if args.tensorflow:
if is_tf_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=True,
xla=args.xla,
amp=args.amp,
save_to_csv=args.save_to_csv,
csv_time_filename=args.csv_time_filename,
csv_memory_filename=args.csv_memory_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
print_fn=print_fn,
)
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
if __name__ == "__main__":
main()
......@@ -6,3 +6,4 @@ sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.3 # April 10, 2020 release
matplotlib
......@@ -19,19 +19,6 @@ else:
import logging
# Benchmarking
from .benchmark_utils import (
Frame,
Memory,
MemoryState,
MemorySummary,
MemoryTrace,
UsedMemoryState,
bytes_to_human_readable,
start_memory_tracing,
stop_memory_tracing,
)
# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
......@@ -358,6 +345,9 @@ if is_torch_available():
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
# Benchmarks
from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments
# TensorFlow
if is_tf_available():
from .modeling_tf_utils import (
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ..file_utils import is_torch_available
if is_torch_available():
from .benchmark_args import PyTorchBenchmarkArguments
from .benchmark import PyTorchBenchmark
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmarking the library on inference and training in PyTorch.
"""
import inspect
import logging
import timeit
from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing
if is_torch_available():
import torch
from .benchmark_args import PyTorchBenchmarkArguments
logger = logging.getLogger(__name__)
class PyTorchBenchmark(Benchmark):
args: PyTorchBenchmarkArguments
configs: PretrainedConfig
framework: str = "PyTorch"
@property
def framework_version(self):
return torch.__version__
def train(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
model.to(self.args.device)
model.train()
input_ids = torch.randint(
model.config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
def compute_loss_and_backprob():
# TODO: Not all models call labels argument labels => this hack using the function signature should be corrected once all models have a common name for labels
function_argument_names = inspect.getfullargspec(model.forward).args
if "labels" in function_argument_names:
loss = model(input_ids, labels=input_ids)[0]
elif "lm_labels" in function_argument_names:
loss = model(input_ids, lm_labels=input_ids)[0]
elif "masked_lm_labels" in function_argument_names:
loss = model(input_ids, masked_lm_labels=input_ids)[0]
else:
NotImplementedError(f"{model_name} does not seem to allow training with labels")
loss.backward()
model.zero_grad()
if trace_memory is True:
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
trace = start_memory_tracing("transformers")
else:
# clear cuda cache
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# calculate loss and do backpropagation
compute_loss_and_backprob()
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
summary = stop_memory_tracing(trace)
memory = summary.total
else:
memory = Memory(torch.cuda.max_memory_reserved())
return memory
else:
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(lambda: compute_loss_and_backprob(), repeat=self.args.repeat, number=10,)
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = MODEL_MAPPING[config.__class__](config)
model.to(self.args.device)
model.eval()
input_ids = torch.randint(
config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if trace_memory is True:
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
trace = start_memory_tracing("transformers")
else:
# clear cuda cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
torch.cuda.reset_peak_memory_stats()
else:
logger.info(
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
torch.cuda.reset_max_memory_cached()
model(input_ids)
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
summary = stop_memory_tracing(trace)
memory = summary.total
else:
if hasattr(torch.cuda, "max_memory_reserved"):
memory = Memory(torch.cuda.max_memory_reserved())
else:
logger.info(
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
memory = Memory(torch.cuda.max_memory_cached())
return memory
else:
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(lambda: model(input_ids), repeat=self.args.repeat, number=10,)
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_torch_available, torch_required
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
try:
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"})
torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
logger.info("PyTorch: setting up devices")
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_tpu_available():
device = xm.xla_device()
n_gpu = 0
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
return device, n_gpu
@property
@torch_required
def device_idx(self) -> int:
return torch.cuda.current_device()
@property
@torch_required
def device(self) -> "torch.device":
return self._setup_devices[0]
@property
@torch_required
def n_gpu(self):
return self._setup_devices[1]
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
from dataclasses import dataclass, field
from time import time
from typing import List
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BenchmarkArguments:
"""
BenchMarkArguments are arguments we use in our benchmark scripts
**which relate to the training loop itself**.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
models: List[str] = list_field(
default=[],
metadata={
"help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
},
)
batch_sizes: List[int] = list_field(
default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
)
sequence_lengths: List[int] = list_field(
default=[8, 32, 128, 512],
metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
)
no_inference: bool = field(default=False, metadata={"help": "Don't benchmark inference of model"})
training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
no_speed: bool = field(default=False, metadata={"help": "Don't perform speed measurments"})
no_memory: bool = field(default=False, metadata={"help": "Don't perform memory measurments"})
trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
inference_time_csv_file: str = field(
default=f"inference_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv."},
)
inference_memory_csv_file: str = field(
default=f"inference_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv."},
)
train_time_csv_file: str = field(
default=f"train_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv for training."},
)
train_memory_csv_file: str = field(
default=f"train_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv for training."},
)
env_info_csv_file: str = field(
default=f"env_info_{round(time())}.csv",
metadata={"help": "CSV filename used if saving environment information."},
)
log_filename: str = field(
default=f"log_{round(time())}.csv",
metadata={"help": "Log filename used if print statements are saved in log."},
)
repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
def to_json_string(self):
"""
Serializes this instance to a JSON string.
"""
return json.dumps(dataclasses.asdict(self), indent=2)
@property
def model_names(self):
return self.models
......@@ -4,18 +4,28 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all
Copyright by the AllenNLP authors.
"""
import copy
import csv
import linecache
import logging
import os
import platform
import sys
from collections import defaultdict
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
from typing import Iterable, List, NamedTuple, Optional, Union
from .file_utils import is_tf_available, is_torch_available
from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
from ..file_utils import is_tf_available, is_torch_available
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
from torch.cuda import empty_cache as torch_empty_cache
if is_tf_available():
from tensorflow.python.eager import context as tf_context
......@@ -25,6 +35,10 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False
BenchmarkOutput = namedtuple(
"BenchmarkOutput", ["time_inference_result", "memory_inference_result", "time_train_result", "memory_train_result"]
)
def is_memory_tracing_enabled():
global _is_memory_tracing_enabled
......@@ -62,14 +76,14 @@ class UsedMemoryState(NamedTuple):
class Memory(NamedTuple):
""" `Memory` NamedTuple have a single field `bytes` and
you can get a human readable string of the number of bytes by calling `__repr__`
you can get a human readable str of the number of mega bytes by calling `__repr__`
- `byte` (integer): number of bytes,
"""
bytes: int
def __repr__(self) -> str:
return bytes_to_human_readable(self.bytes)
return str(bytes_to_mega_bytes(self.bytes))
class MemoryState(NamedTuple):
......@@ -99,6 +113,7 @@ class MemorySummary(NamedTuple):
sequential: List[MemoryState]
cumulative: List[MemoryState]
current: List[MemoryState]
total: Memory
......@@ -234,10 +249,12 @@ def start_memory_tracing(
# Sum used memory for all GPUs
py3nvml.nvmlInit()
for i in devices:
handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem += meminfo.used
py3nvml.nvmlShutdown()
mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
......@@ -295,8 +312,11 @@ def stop_memory_tracing(
if memory_trace is not None and len(memory_trace) > 1:
memory_diff_trace = []
memory_curr_trace = []
cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
for (frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem) in zip(
for ((frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem),) in zip(
memory_trace[:-1], memory_trace[1:]
):
cpu_mem_inc = next_cpu_mem - cpu_mem
......@@ -307,6 +327,16 @@ def stop_memory_tracing(
frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
)
)
memory_curr_trace.append(
MemoryState(
frame=frame,
cpu=Memory(next_cpu_mem),
gpu=Memory(next_gpu_mem),
cpu_gpu=Memory(next_gpu_mem + next_cpu_mem),
)
)
cumulative_memory_dict[frame][0] += cpu_mem_inc
cumulative_memory_dict[frame][1] += gpu_mem_inc
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
......@@ -321,21 +351,287 @@ def stop_memory_tracing(
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
)
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
if ignore_released_memory:
total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
else:
total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
total_memory = Memory(total_memory)
return MemorySummary(sequential=memory_diff_trace, cumulative=cumulative_memory, total=total_memory)
return MemorySummary(
sequential=memory_diff_trace, cumulative=cumulative_memory, current=memory_curr_trace, total=total_memory,
)
return None
def bytes_to_human_readable(memory_amount: int) -> str:
""" Utility to convert a number of bytes (int) in a human readable string (with units)
def bytes_to_mega_bytes(memory_amount: int) -> int:
""" Utility to convert a number of bytes (int) into a number of mega bytes (int)
"""
return memory_amount >> 20
class Benchmark(ABC):
"""
Benchmarks is a simple but feature-complete benchmarking script
to compare memory and time performance of models in Transformers.
"""
for unit in ["B", "KB", "MB", "GB"]:
if memory_amount > -1024.0 and memory_amount < 1024.0:
return "{:.3f}{}".format(memory_amount, unit)
memory_amount /= 1024.0
return "{:.3f}TB".format(memory_amount)
args: BenchmarkArguments
configs: PretrainedConfig
framework: str
def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
self.args = args
if configs is None:
self.config_dict = {
model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
}
else:
self.config_dict = {model_name: config for model_name, config in zip(self.args.model_names, configs)}
self._print_fn = None
self._framework_version = None
self._environment_info = None
@property
def print_fn(self):
if self._print_fn is None:
if self.args.log_print:
logging.basicConfig(
level=logging.DEBUG,
filename=self.args.log_filename,
filemode="a+",
format="%(asctime)-15s %(levelname)-8s %(message)s",
)
def print_and_log(*args):
logging.info(*args)
print(*args)
self._print_fn = print_and_log
else:
self._print_fn = print
return self._print_fn
@property
def is_gpu(self):
return self.args.n_gpu > 0
@property
@abstractmethod
def framework_version(self):
pass
@abstractmethod
def train(self, model_name, batch_size, sequence_length):
pass
@abstractmethod
def inference(self, model_name, batch_size, sequence_length):
pass
def run(self):
result_dict = {model_name: {} for model_name in self.args.model_names}
inference_result_time = copy.deepcopy(result_dict)
inference_result_memory = copy.deepcopy(result_dict)
train_result_time = copy.deepcopy(result_dict)
train_result_memory = copy.deepcopy(result_dict)
for c, model_name in enumerate(self.args.model_names):
self.print_fn(f"{c + 1} / {len(self.args.model_names)}")
model_dict = {
"bs": self.args.batch_sizes,
"ss": self.args.sequence_lengths,
"result": {i: {} for i in self.args.batch_sizes},
}
inference_result_time[model_name] = copy.deepcopy(model_dict)
inference_result_memory[model_name] = copy.deepcopy(model_dict)
train_result_time[model_name] = copy.deepcopy(model_dict)
train_result_memory[model_name] = copy.deepcopy(model_dict)
for batch_size in self.args.batch_sizes:
for sequence_length in self.args.sequence_lengths:
if not self.args.no_inference:
if not self.args.no_memory:
memory = self.inference(model_name, batch_size, sequence_length, trace_memory=True)
inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
if not self.args.no_speed:
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
inference_result_time[model_name]["result"][batch_size][sequence_length] = time
if self.args.training:
if not self.args.no_memory:
memory = self.train(model_name, batch_size, sequence_length, trace_memory=True)
train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
if not self.args.no_speed:
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
train_result_time[model_name]["result"][batch_size][sequence_length] = time
if not self.args.no_inference:
if not self.args.no_speed:
self.print_fn("======= INFERENCE - SPEED - RESULT =======")
self.print_results(inference_result_time)
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if not self.args.no_memory:
self.print_fn("======= INFERENCE - MEMORY - RESULT =======")
self.print_results(inference_result_memory)
self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
if self.args.training:
if not self.args.no_speed:
self.print_fn("======= TRAIN - SPEED - RESULT =======")
self.print_results(train_result_time)
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if not self.args.no_memory:
self.print_fn("======= TRAIN - MEMORY - RESULT =======")
self.print_results(train_result_memory)
self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
if not self.args.no_env_print:
self.print_fn("\n======== ENVIRONMENT - INFORMATION ========")
self.print_fn(
"\n".join(["- {}: {}".format(prop, val) for prop, val in self.environment_info.items()]) + "\n"
)
if self.args.save_to_csv:
with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
writer = csv.writer(csv_file)
for key, value in self.environment_info.items():
writer.writerow([key, value])
return BenchmarkOutput(inference_result_time, inference_result_memory, train_result_time, train_result_memory)
@property
def environment_info(self):
if self._environment_info is None:
info = {}
info["transformers_version"] = version
info["framework"] = self.framework
info["framework_version"] = self.framework_version
info["python_version"] = platform.python_version()
info["system"] = platform.system()
info["cpu"] = platform.processor()
info["architecture"] = platform.architecture()[0]
info["date"] = datetime.date(datetime.now())
info["time"] = datetime.time(datetime.now())
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log available CPU memory."
"Install psutil (pip install psutil) to log available CPU memory."
)
info["cpu_ram_mb"] = "N/A"
else:
info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
info["use_gpu"] = self.is_gpu
if self.is_gpu:
info["num_gpus"] = self.args.n_gpu
try:
from py3nvml import py3nvml
py3nvml.nvmlInit()
handle = py3nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
except ImportError:
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to log information about GPU."
)
info["gpu"] = "N/A"
info["gpu_ram_mb"] = "N/A"
info["gpu_power_watts"] = "N/A"
info["gpu_performance_state"] = "N/A"
except (OSError, py3nvml.NVMLError):
logger.warning(
"Error while initializing comunication with GPU. " "We won't log information about GPU."
)
info["gpu"] = "N/A"
info["gpu_ram_mb"] = "N/A"
info["gpu_power_watts"] = "N/A"
info["gpu_performance_state"] = "N/A"
py3nvml.nvmlShutdown()
else:
info["gpu"] = py3nvml.nvmlDeviceGetName(handle)
info["gpu_ram_mb"] = bytes_to_mega_bytes(py3nvml.nvmlDeviceGetMemoryInfo(handle).total)
info["gpu_power_watts"] = py3nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle)
py3nvml.nvmlShutdown()
self._environment_info = info
return self._environment_info
def print_results(self, result_dict):
for model_name in self.args.model_names:
self.print_fn("\t" + f"======= MODEL CHECKPOINT: {model_name} =======")
for batch_size in result_dict[model_name]["bs"]:
for sequence_length in result_dict[model_name]["ss"]:
result = result_dict[model_name]["result"][batch_size][sequence_length]
if isinstance(result, float):
self.print_fn(
f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{(round(1000 * result) / 1000)}s"
)
else:
self.print_fn(f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{result} MB")
def print_memory_trace_statistics(self, summary: MemorySummary):
self.print_fn(
"\nLine by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
self.print_fn(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
self.print_fn(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
self.print_fn(f"\nTotal memory increase: {summary.total}")
def save_to_csv(self, result_dict, filename):
if not self.args.save_to_csv:
return
self.print_fn("Saving results to csv.")
with open(filename, mode="w") as csv_file:
assert len(self.args.model_names) > 0, "At least 1 model should be defined, but got {}".format(
self.model_names
)
fieldnames = ["model", "batch_size", "sequence_length"]
writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
writer.writeheader()
for model_name in self.args.model_names:
result_dict_model = result_dict[model_name]["result"]
for bs in result_dict_model:
for ss in result_dict_model[bs]:
result_model = result_dict_model[bs][ss]
writer.writerow(
{
"model": model_name,
"batch_size": bs,
"sequence_length": ss,
"result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
result_model
),
}
)
......@@ -59,6 +59,7 @@ try:
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
from torch.hub import _get_torch_home
......
......@@ -4,7 +4,7 @@ import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union
from typing import Any, Iterable, List, NewType, Tuple, Union
DataClass = NewType("DataClass", Any)
......@@ -52,9 +52,13 @@ class HfArgumentParser(ArgumentParser):
"We will add compatibility when Python 3.9 is released."
)
typestring = str(field.type)
for x in (int, float, str):
if typestring == f"typing.Union[{x.__name__}, NoneType]":
field.type = x
for prim_type in (int, float, str):
for collection in (List,):
if typestring == f"typing.Union[{collection[prim_type]}, NoneType]":
field.type = collection[prim_type]
if typestring == f"typing.Union[{prim_type.__name__}, NoneType]":
field.type = prim_type
if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type)
kwargs["type"] = field.type
......@@ -65,6 +69,14 @@ class HfArgumentParser(ArgumentParser):
if field.default is True:
field_name = f"--no-{field.name}"
kwargs["dest"] = field.name
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]
assert all(
x == kwargs["type"] for x in field.type.__args__
), "{} cannot be a List of mixed types".format(field.name)
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
else:
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING:
......
import os
import tempfile
import unittest
from pathlib import Path
from transformers import GPT2Config, is_torch_available
from .utils import require_torch
if is_torch_available():
from transformers import (
PyTorchBenchmarkArguments,
PyTorchBenchmark,
)
@require_torch
class BenchmarkTest(unittest.TestCase):
def check_results_dict_not_empty(self, results):
for model_result in results.values():
for batch_size, sequence_length in zip(model_result["bs"], model_result["ss"]):
result = model_result["result"][batch_size][sequence_length]
self.assertIsNotNone(result)
def test_inference_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_inference_with_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = GPT2Config.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_with_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = GPT2Config.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_save_csv_files(self):
MODEL_ID = "sshleifer/tiny-gpt2"
with tempfile.TemporaryDirectory() as tmp_dir:
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=False,
save_to_csv=True,
sequence_lengths=[8],
batch_sizes=[1],
inference_time_csv_file=os.path.join(tmp_dir, "inf_time.csv"),
train_memory_csv_file=os.path.join(tmp_dir, "train_mem.csv"),
inference_memory_csv_file=os.path.join(tmp_dir, "inf_mem.csv"),
train_time_csv_file=os.path.join(tmp_dir, "train_time.csv"),
env_info_csv_file=os.path.join(tmp_dir, "env.csv"),
)
benchmark = PyTorchBenchmark(benchmark_args)
benchmark.run()
self.assertTrue(Path(os.path.join(tmp_dir, "inf_time.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "train_time.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())
......@@ -3,11 +3,15 @@ import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BasicExample:
foo: int
......@@ -43,6 +47,16 @@ class OptionalExample:
foo: Optional[int] = None
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
baz: Optional[str] = None
ces: Optional[List[str]] = list_field(default=[])
des: Optional[List[int]] = list_field(default=[])
@dataclass
class ListExample:
foo_int: List[int] = list_field(default=[])
bar_int: List[int] = list_field(default=[1, 2, 3])
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
class HfArgumentParserTest(unittest.TestCase):
......@@ -101,6 +115,26 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, BasicEnum.titi)
def test_with_list(self):
parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(
args,
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
)
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_with_optional(self):
parser = HfArgumentParser(OptionalExample)
......@@ -108,13 +142,15 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", default=None, type=int)
expected.add_argument("--bar", default=None, type=float, help="help message")
expected.add_argument("--baz", default=None, type=str)
expected.add_argument("--ces", nargs="+", default=[], type=str)
expected.add_argument("--des", nargs="+", default=[], type=int)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment