Unverified Commit 96f57c9c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Benchmark] Memory benchmark utils (#4198)



* improve memory benchmarking

* correct typo

* fix current memory

* check torch memory allocated

* better pytorch function

* add total cached gpu memory

* add total gpu required

* improve torch gpu usage

* update memory usage

* finalize memory tracing

* save intermediate benchmark class

* fix conflict

* improve benchmark

* improve benchmark

* finalize

* make style

* improve benchmarking

* correct typo

* make train function more flexible

* fix csv save

* better repr of bytes

* better print

* fix __repr__ bug

* finish plot script

* rename plot file

* delete csv and small improvements

* fix in plot

* fix in plot

* correct usage of timeit

* remove redundant line

* remove redundant line

* fix bug

* add hf parser tests

* add versioning and platform info

* make style

* add gpu information

* ensure backward compatibility

* finish adding all tests

* Update src/transformers/benchmark/benchmark_args.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Update src/transformers/benchmark/benchmark_args_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* delete csv files

* fix isort ordering

* add out of memory handling

* add better train memory handling
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent ec4cdfdd
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
from transformers import HfArgumentParser
@dataclass
class PlotArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
csv_file: str = field(metadata={"help": "The csv file to plot."},)
plot_along_batch: bool = field(
default=False,
metadata={"help": "Whether to plot along batch size or sequence lengh. Defaults to sequence length."},
)
is_time: bool = field(
default=False,
metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
)
is_train: bool = field(
default=False,
metadata={
"help": "Whether the csv file has training results or inference results. Defaults to inference results."
},
)
figure_png_file: Optional[str] = field(
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
)
class Plot:
def __init__(self, args):
self.args = args
self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
with open(self.args.csv_file, newline="") as csv_file:
reader = csv.DictReader(csv_file)
for row in reader:
model_name = row["model"]
self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
"result"
]
def plot(self):
fig, ax = plt.subplots()
title_str = "Time usage" if self.args.is_time else "Memory usage"
title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"
for model_name in self.result_dict.keys():
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
results = self.result_dict[model_name]["result"]
(x_axis_array, inner_loop_array) = (
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
)
plt.xlim(min(x_axis_array), max(x_axis_array))
for inner_loop_value in inner_loop_array:
if self.args.plot_along_batch:
y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
else:
y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
ax.set_xscale("log", basex=2)
ax.set_yscale("log", basey=10)
(x_axis_label, inner_loop_label) = (
("batch_size", "sequence_length in #tokens")
if self.args.plot_along_batch
else ("sequence_length in #tokens", "batch_size")
)
x_axis_array = np.asarray(x_axis_array, np.int)
plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
plt.plot(x_axis_array, y_axis_array, "--")
title_str += f" {model_name} vs."
title_str = title_str[:-4]
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
# plot
plt.title(title_str)
plt.xlabel(x_axis_label)
plt.ylabel(y_axis_label)
plt.legend()
if self.args.figure_png_file is not None:
plt.savefig(self.args.figure_png_file)
else:
plt.show()
def main():
parser = HfArgumentParser(PlotArguments)
plot_args = parser.parse_args_into_dataclasses()[0]
plot = Plot(args=plot_args)
plot.plot()
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """
from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments
def main():
parser = HfArgumentParser(PyTorchBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
benchmark = PyTorchBenchmark(args=benchmark_args)
benchmark.run()
if __name__ == "__main__":
main()
This diff is collapsed.
......@@ -6,3 +6,4 @@ sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.3 # April 10, 2020 release
matplotlib
......@@ -19,19 +19,6 @@ else:
import logging
# Benchmarking
from .benchmark_utils import (
Frame,
Memory,
MemoryState,
MemorySummary,
MemoryTrace,
UsedMemoryState,
bytes_to_human_readable,
start_memory_tracing,
stop_memory_tracing,
)
# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
......@@ -358,6 +345,9 @@ if is_torch_available():
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
# Benchmarks
from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments
# TensorFlow
if is_tf_available():
from .modeling_tf_utils import (
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ..file_utils import is_torch_available
if is_torch_available():
from .benchmark_args import PyTorchBenchmarkArguments
from .benchmark import PyTorchBenchmark
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmarking the library on inference and training in PyTorch.
"""
import inspect
import logging
import timeit
from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing
if is_torch_available():
import torch
from .benchmark_args import PyTorchBenchmarkArguments
logger = logging.getLogger(__name__)
class PyTorchBenchmark(Benchmark):
args: PyTorchBenchmarkArguments
configs: PretrainedConfig
framework: str = "PyTorch"
@property
def framework_version(self):
return torch.__version__
def train(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
model.to(self.args.device)
model.train()
input_ids = torch.randint(
model.config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
def compute_loss_and_backprob():
# TODO: Not all models call labels argument labels => this hack using the function signature should be corrected once all models have a common name for labels
function_argument_names = inspect.getfullargspec(model.forward).args
if "labels" in function_argument_names:
loss = model(input_ids, labels=input_ids)[0]
elif "lm_labels" in function_argument_names:
loss = model(input_ids, lm_labels=input_ids)[0]
elif "masked_lm_labels" in function_argument_names:
loss = model(input_ids, masked_lm_labels=input_ids)[0]
else:
NotImplementedError(f"{model_name} does not seem to allow training with labels")
loss.backward()
model.zero_grad()
if trace_memory is True:
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
trace = start_memory_tracing("transformers")
else:
# clear cuda cache
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# calculate loss and do backpropagation
compute_loss_and_backprob()
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
summary = stop_memory_tracing(trace)
memory = summary.total
else:
memory = Memory(torch.cuda.max_memory_reserved())
return memory
else:
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(lambda: compute_loss_and_backprob(), repeat=self.args.repeat, number=10,)
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = MODEL_MAPPING[config.__class__](config)
model.to(self.args.device)
model.eval()
input_ids = torch.randint(
config.vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if trace_memory is True:
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
trace = start_memory_tracing("transformers")
else:
# clear cuda cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
torch.cuda.reset_peak_memory_stats()
else:
logger.info(
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
torch.cuda.reset_max_memory_cached()
model(input_ids)
if self.args.trace_memory_line_by_line or self.args.n_gpu == 0:
summary = stop_memory_tracing(trace)
memory = summary.total
else:
if hasattr(torch.cuda, "max_memory_reserved"):
memory = Memory(torch.cuda.max_memory_reserved())
else:
logger.info(
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
memory = Memory(torch.cuda.max_memory_cached())
return memory
else:
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(lambda: model(input_ids), repeat=self.args.repeat, number=10,)
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_torch_available, torch_required
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
try:
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"})
torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
logger.info("PyTorch: setting up devices")
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_tpu_available():
device = xm.xla_device()
n_gpu = 0
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
return device, n_gpu
@property
@torch_required
def device_idx(self) -> int:
return torch.cuda.current_device()
@property
@torch_required
def device(self) -> "torch.device":
return self._setup_devices[0]
@property
@torch_required
def n_gpu(self):
return self._setup_devices[1]
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
from dataclasses import dataclass, field
from time import time
from typing import List
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BenchmarkArguments:
"""
BenchMarkArguments are arguments we use in our benchmark scripts
**which relate to the training loop itself**.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
models: List[str] = list_field(
default=[],
metadata={
"help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
},
)
batch_sizes: List[int] = list_field(
default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
)
sequence_lengths: List[int] = list_field(
default=[8, 32, 128, 512],
metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
)
no_inference: bool = field(default=False, metadata={"help": "Don't benchmark inference of model"})
training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
no_speed: bool = field(default=False, metadata={"help": "Don't perform speed measurments"})
no_memory: bool = field(default=False, metadata={"help": "Don't perform memory measurments"})
trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
inference_time_csv_file: str = field(
default=f"inference_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv."},
)
inference_memory_csv_file: str = field(
default=f"inference_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv."},
)
train_time_csv_file: str = field(
default=f"train_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv for training."},
)
train_memory_csv_file: str = field(
default=f"train_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv for training."},
)
env_info_csv_file: str = field(
default=f"env_info_{round(time())}.csv",
metadata={"help": "CSV filename used if saving environment information."},
)
log_filename: str = field(
default=f"log_{round(time())}.csv",
metadata={"help": "Log filename used if print statements are saved in log."},
)
repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
def to_json_string(self):
"""
Serializes this instance to a JSON string.
"""
return json.dumps(dataclasses.asdict(self), indent=2)
@property
def model_names(self):
return self.models
......@@ -4,18 +4,28 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all
Copyright by the AllenNLP authors.
"""
import copy
import csv
import linecache
import logging
import os
import platform
import sys
from collections import defaultdict
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
from typing import Iterable, List, NamedTuple, Optional, Union
from .file_utils import is_tf_available, is_torch_available
from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
from ..file_utils import is_tf_available, is_torch_available
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
from torch.cuda import empty_cache as torch_empty_cache
if is_tf_available():
from tensorflow.python.eager import context as tf_context
......@@ -25,6 +35,10 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False
BenchmarkOutput = namedtuple(
"BenchmarkOutput", ["time_inference_result", "memory_inference_result", "time_train_result", "memory_train_result"]
)
def is_memory_tracing_enabled():
global _is_memory_tracing_enabled
......@@ -62,14 +76,14 @@ class UsedMemoryState(NamedTuple):
class Memory(NamedTuple):
""" `Memory` NamedTuple have a single field `bytes` and
you can get a human readable string of the number of bytes by calling `__repr__`
you can get a human readable str of the number of mega bytes by calling `__repr__`
- `byte` (integer): number of bytes,
"""
bytes: int
def __repr__(self) -> str:
return bytes_to_human_readable(self.bytes)
return str(bytes_to_mega_bytes(self.bytes))
class MemoryState(NamedTuple):
......@@ -99,6 +113,7 @@ class MemorySummary(NamedTuple):
sequential: List[MemoryState]
cumulative: List[MemoryState]
current: List[MemoryState]
total: Memory
......@@ -234,10 +249,12 @@ def start_memory_tracing(
# Sum used memory for all GPUs
py3nvml.nvmlInit()
for i in devices:
handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem += meminfo.used
py3nvml.nvmlShutdown()
mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
......@@ -295,8 +312,11 @@ def stop_memory_tracing(
if memory_trace is not None and len(memory_trace) > 1:
memory_diff_trace = []
memory_curr_trace = []
cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
for (frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem) in zip(
for ((frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem),) in zip(
memory_trace[:-1], memory_trace[1:]
):
cpu_mem_inc = next_cpu_mem - cpu_mem
......@@ -307,6 +327,16 @@ def stop_memory_tracing(
frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
)
)
memory_curr_trace.append(
MemoryState(
frame=frame,
cpu=Memory(next_cpu_mem),
gpu=Memory(next_gpu_mem),
cpu_gpu=Memory(next_gpu_mem + next_cpu_mem),
)
)
cumulative_memory_dict[frame][0] += cpu_mem_inc
cumulative_memory_dict[frame][1] += gpu_mem_inc
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
......@@ -321,21 +351,287 @@ def stop_memory_tracing(
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
)
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
if ignore_released_memory:
total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
else:
total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
total_memory = Memory(total_memory)
return MemorySummary(sequential=memory_diff_trace, cumulative=cumulative_memory, total=total_memory)
return MemorySummary(
sequential=memory_diff_trace, cumulative=cumulative_memory, current=memory_curr_trace, total=total_memory,
)
return None
def bytes_to_human_readable(memory_amount: int) -> str:
""" Utility to convert a number of bytes (int) in a human readable string (with units)
def bytes_to_mega_bytes(memory_amount: int) -> int:
""" Utility to convert a number of bytes (int) into a number of mega bytes (int)
"""
return memory_amount >> 20
class Benchmark(ABC):
"""
Benchmarks is a simple but feature-complete benchmarking script
to compare memory and time performance of models in Transformers.
"""
for unit in ["B", "KB", "MB", "GB"]:
if memory_amount > -1024.0 and memory_amount < 1024.0:
return "{:.3f}{}".format(memory_amount, unit)
memory_amount /= 1024.0
return "{:.3f}TB".format(memory_amount)
args: BenchmarkArguments
configs: PretrainedConfig
framework: str
def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
self.args = args
if configs is None:
self.config_dict = {
model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
}
else:
self.config_dict = {model_name: config for model_name, config in zip(self.args.model_names, configs)}
self._print_fn = None
self._framework_version = None
self._environment_info = None
@property
def print_fn(self):
if self._print_fn is None:
if self.args.log_print:
logging.basicConfig(
level=logging.DEBUG,
filename=self.args.log_filename,
filemode="a+",
format="%(asctime)-15s %(levelname)-8s %(message)s",
)
def print_and_log(*args):
logging.info(*args)
print(*args)
self._print_fn = print_and_log
else:
self._print_fn = print
return self._print_fn
@property
def is_gpu(self):
return self.args.n_gpu > 0
@property
@abstractmethod
def framework_version(self):
pass
@abstractmethod
def train(self, model_name, batch_size, sequence_length):
pass
@abstractmethod
def inference(self, model_name, batch_size, sequence_length):
pass
def run(self):
result_dict = {model_name: {} for model_name in self.args.model_names}
inference_result_time = copy.deepcopy(result_dict)
inference_result_memory = copy.deepcopy(result_dict)
train_result_time = copy.deepcopy(result_dict)
train_result_memory = copy.deepcopy(result_dict)
for c, model_name in enumerate(self.args.model_names):
self.print_fn(f"{c + 1} / {len(self.args.model_names)}")
model_dict = {
"bs": self.args.batch_sizes,
"ss": self.args.sequence_lengths,
"result": {i: {} for i in self.args.batch_sizes},
}
inference_result_time[model_name] = copy.deepcopy(model_dict)
inference_result_memory[model_name] = copy.deepcopy(model_dict)
train_result_time[model_name] = copy.deepcopy(model_dict)
train_result_memory[model_name] = copy.deepcopy(model_dict)
for batch_size in self.args.batch_sizes:
for sequence_length in self.args.sequence_lengths:
if not self.args.no_inference:
if not self.args.no_memory:
memory = self.inference(model_name, batch_size, sequence_length, trace_memory=True)
inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
if not self.args.no_speed:
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
inference_result_time[model_name]["result"][batch_size][sequence_length] = time
if self.args.training:
if not self.args.no_memory:
memory = self.train(model_name, batch_size, sequence_length, trace_memory=True)
train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
if not self.args.no_speed:
time = self.inference(model_name, batch_size, sequence_length, trace_memory=False)
train_result_time[model_name]["result"][batch_size][sequence_length] = time
if not self.args.no_inference:
if not self.args.no_speed:
self.print_fn("======= INFERENCE - SPEED - RESULT =======")
self.print_results(inference_result_time)
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if not self.args.no_memory:
self.print_fn("======= INFERENCE - MEMORY - RESULT =======")
self.print_results(inference_result_memory)
self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
if self.args.training:
if not self.args.no_speed:
self.print_fn("======= TRAIN - SPEED - RESULT =======")
self.print_results(train_result_time)
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if not self.args.no_memory:
self.print_fn("======= TRAIN - MEMORY - RESULT =======")
self.print_results(train_result_memory)
self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
if not self.args.no_env_print:
self.print_fn("\n======== ENVIRONMENT - INFORMATION ========")
self.print_fn(
"\n".join(["- {}: {}".format(prop, val) for prop, val in self.environment_info.items()]) + "\n"
)
if self.args.save_to_csv:
with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
writer = csv.writer(csv_file)
for key, value in self.environment_info.items():
writer.writerow([key, value])
return BenchmarkOutput(inference_result_time, inference_result_memory, train_result_time, train_result_memory)
@property
def environment_info(self):
if self._environment_info is None:
info = {}
info["transformers_version"] = version
info["framework"] = self.framework
info["framework_version"] = self.framework_version
info["python_version"] = platform.python_version()
info["system"] = platform.system()
info["cpu"] = platform.processor()
info["architecture"] = platform.architecture()[0]
info["date"] = datetime.date(datetime.now())
info["time"] = datetime.time(datetime.now())
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log available CPU memory."
"Install psutil (pip install psutil) to log available CPU memory."
)
info["cpu_ram_mb"] = "N/A"
else:
info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
info["use_gpu"] = self.is_gpu
if self.is_gpu:
info["num_gpus"] = self.args.n_gpu
try:
from py3nvml import py3nvml
py3nvml.nvmlInit()
handle = py3nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
except ImportError:
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to log information about GPU."
)
info["gpu"] = "N/A"
info["gpu_ram_mb"] = "N/A"
info["gpu_power_watts"] = "N/A"
info["gpu_performance_state"] = "N/A"
except (OSError, py3nvml.NVMLError):
logger.warning(
"Error while initializing comunication with GPU. " "We won't log information about GPU."
)
info["gpu"] = "N/A"
info["gpu_ram_mb"] = "N/A"
info["gpu_power_watts"] = "N/A"
info["gpu_performance_state"] = "N/A"
py3nvml.nvmlShutdown()
else:
info["gpu"] = py3nvml.nvmlDeviceGetName(handle)
info["gpu_ram_mb"] = bytes_to_mega_bytes(py3nvml.nvmlDeviceGetMemoryInfo(handle).total)
info["gpu_power_watts"] = py3nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle)
py3nvml.nvmlShutdown()
self._environment_info = info
return self._environment_info
def print_results(self, result_dict):
for model_name in self.args.model_names:
self.print_fn("\t" + f"======= MODEL CHECKPOINT: {model_name} =======")
for batch_size in result_dict[model_name]["bs"]:
for sequence_length in result_dict[model_name]["ss"]:
result = result_dict[model_name]["result"][batch_size][sequence_length]
if isinstance(result, float):
self.print_fn(
f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{(round(1000 * result) / 1000)}s"
)
else:
self.print_fn(f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{result} MB")
def print_memory_trace_statistics(self, summary: MemorySummary):
self.print_fn(
"\nLine by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
self.print_fn(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
self.print_fn(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
self.print_fn(f"\nTotal memory increase: {summary.total}")
def save_to_csv(self, result_dict, filename):
if not self.args.save_to_csv:
return
self.print_fn("Saving results to csv.")
with open(filename, mode="w") as csv_file:
assert len(self.args.model_names) > 0, "At least 1 model should be defined, but got {}".format(
self.model_names
)
fieldnames = ["model", "batch_size", "sequence_length"]
writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
writer.writeheader()
for model_name in self.args.model_names:
result_dict_model = result_dict[model_name]["result"]
for bs in result_dict_model:
for ss in result_dict_model[bs]:
result_model = result_dict_model[bs][ss]
writer.writerow(
{
"model": model_name,
"batch_size": bs,
"sequence_length": ss,
"result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
result_model
),
}
)
......@@ -59,6 +59,7 @@ try:
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
from torch.hub import _get_torch_home
......
......@@ -4,7 +4,7 @@ import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union
from typing import Any, Iterable, List, NewType, Tuple, Union
DataClass = NewType("DataClass", Any)
......@@ -52,9 +52,13 @@ class HfArgumentParser(ArgumentParser):
"We will add compatibility when Python 3.9 is released."
)
typestring = str(field.type)
for x in (int, float, str):
if typestring == f"typing.Union[{x.__name__}, NoneType]":
field.type = x
for prim_type in (int, float, str):
for collection in (List,):
if typestring == f"typing.Union[{collection[prim_type]}, NoneType]":
field.type = collection[prim_type]
if typestring == f"typing.Union[{prim_type.__name__}, NoneType]":
field.type = prim_type
if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type)
kwargs["type"] = field.type
......@@ -65,6 +69,14 @@ class HfArgumentParser(ArgumentParser):
if field.default is True:
field_name = f"--no-{field.name}"
kwargs["dest"] = field.name
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]
assert all(
x == kwargs["type"] for x in field.type.__args__
), "{} cannot be a List of mixed types".format(field.name)
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
else:
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING:
......
import os
import tempfile
import unittest
from pathlib import Path
from transformers import GPT2Config, is_torch_available
from .utils import require_torch
if is_torch_available():
from transformers import (
PyTorchBenchmarkArguments,
PyTorchBenchmark,
)
@require_torch
class BenchmarkTest(unittest.TestCase):
def check_results_dict_not_empty(self, results):
for model_result in results.values():
for batch_size, sequence_length in zip(model_result["bs"], model_result["ss"]):
result = model_result["result"][batch_size][sequence_length]
self.assertIsNotNone(result)
def test_inference_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_inference_with_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = GPT2Config.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_with_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = GPT2Config.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1]
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_save_csv_files(self):
MODEL_ID = "sshleifer/tiny-gpt2"
with tempfile.TemporaryDirectory() as tmp_dir:
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=False,
save_to_csv=True,
sequence_lengths=[8],
batch_sizes=[1],
inference_time_csv_file=os.path.join(tmp_dir, "inf_time.csv"),
train_memory_csv_file=os.path.join(tmp_dir, "train_mem.csv"),
inference_memory_csv_file=os.path.join(tmp_dir, "inf_mem.csv"),
train_time_csv_file=os.path.join(tmp_dir, "train_time.csv"),
env_info_csv_file=os.path.join(tmp_dir, "env.csv"),
)
benchmark = PyTorchBenchmark(benchmark_args)
benchmark.run()
self.assertTrue(Path(os.path.join(tmp_dir, "inf_time.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "train_time.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "train_mem.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists())
......@@ -3,11 +3,15 @@ import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BasicExample:
foo: int
......@@ -43,6 +47,16 @@ class OptionalExample:
foo: Optional[int] = None
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
baz: Optional[str] = None
ces: Optional[List[str]] = list_field(default=[])
des: Optional[List[int]] = list_field(default=[])
@dataclass
class ListExample:
foo_int: List[int] = list_field(default=[])
bar_int: List[int] = list_field(default=[1, 2, 3])
foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
class HfArgumentParserTest(unittest.TestCase):
......@@ -101,6 +115,26 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, BasicEnum.titi)
def test_with_list(self):
parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(
args,
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
)
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_with_optional(self):
parser = HfArgumentParser(OptionalExample)
......@@ -108,13 +142,15 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", default=None, type=int)
expected.add_argument("--bar", default=None, type=float, help="help message")
expected.add_argument("--baz", default=None, type=str)
expected.add_argument("--ces", nargs="+", default=[], type=str)
expected.add_argument("--des", nargs="+", default=[], type=int)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment