Unverified Commit 2cfb947f authored by Patrick von Platen, committed by GitHub

[Benchmark] add tpu and torchscript for benchmark (#4850)



* add tpu and torchscript for benchmark

* fix name in tests

* "fix email"

* make style

* better log message for tpu

* add more print and info for tpu

* allow possibility to print tpu metrics

* correct cpu usage

* fix test for non-install

* remove bogus file

* include psutil in testing

* run a couple of times before tracing in torchscript

* do not allow tpu memory tracing for now

* make style

* add torchscript to env

* better name for torch tpu
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent f0340b30
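For orientation, here is a minimal usage sketch of the options this commit introduces. It assumes the PyTorchBenchmarkArguments fields added below (torchscript, no_tpu, tpu_print_metrics) and that PyTorchBenchmark and PyTorchBenchmarkArguments are importable from the top-level package as the updated tests use them; the model id and remaining argument names are taken from those tests.

# Hypothetical usage sketch, not part of the diff.
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

benchmark_args = PyTorchBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],   # tiny model id used in the new tests
    training=False,
    no_inference=False,
    torchscript=True,                 # trace models with torch.jit.trace before timing
    no_tpu=True,                      # new flag: skip TPU even if torch_xla is installed
    tpu_print_metrics=False,          # new flag: print torch_xla debug metrics after timing
    sequence_lengths=[8],
    batch_sizes=[1],
)
results = PyTorchBenchmark(benchmark_args).run()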
@@ -84,7 +84,7 @@ extras["torch"] = ["torch"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
-extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator"]
+extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil"]
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]
extras["quality"] = [
"black",
......
@@ -78,6 +78,7 @@ from .file_utils import (
cached_path,
is_tf_available,
is_torch_available,
is_torch_tpu_available,
)
from .hf_argparser import HfArgumentParser
......
@@ -19,12 +19,17 @@
import logging
import os
import timeit
-from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
+from transformers import (
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
PretrainedConfig,
is_torch_available,
is_torch_tpu_available,
)
-from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing
+from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing
if is_torch_available():
@@ -48,6 +53,10 @@ class PyTorchBenchmark(Benchmark):
def train(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
if self.args.torchscript:
config.torchscript = True
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
model.to(self.args.device)
model.train()
@@ -58,15 +67,20 @@ class PyTorchBenchmark(Benchmark):
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if self.args.torchscript:
raise NotImplementedError("Training for torchscript is currently not implemented")
else:
train_model = model
def compute_loss_and_backprob_encoder():
-loss = model(input_ids, labels=input_ids)[0]
+loss = train_model(input_ids, labels=input_ids)[0]
loss.backward()
-model.zero_grad()
+train_model.zero_grad()
def compute_loss_and_backprob_encoder_decoder():
-loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
+loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
loss.backward()
-model.zero_grad()
+train_model.zero_grad()
_train = (
compute_loss_and_backprob_encoder_decoder
@@ -79,6 +93,7 @@ class PyTorchBenchmark(Benchmark):
trace = start_memory_tracing("transformers")
if self.args.n_gpu > 0:
# gpu
# clear gpu cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
@@ -89,8 +104,17 @@ class PyTorchBenchmark(Benchmark):
)
torch.cuda.reset_max_memory_cached()
# calculate loss and do backpropagation
_train()
elif not self.args.no_tpu and is_torch_tpu_available():
# tpu
raise NotImplementedError(
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
)
else:
# cpu
memory_bytes = measure_peak_memory_cpu(_train)
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
if self.args.trace_memory_line_by_line:
summary = stop_memory_tracing(trace)
@@ -107,40 +131,47 @@ class PyTorchBenchmark(Benchmark):
)
memory = Memory(torch.cuda.max_memory_cached())
memory = Memory(torch.cuda.max_memory_reserved())
else:
# cpu
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
memory = "N/A"
else:
process = psutil.Process(os.getpid())
memory = Memory(process.memory_info().rss)
return memory, summary
else:
if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
# run additional 10 times to stabilize compilation for tpu and torchscript
logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
timeit.repeat(
_train, repeat=1, number=5,
)
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)
if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
import torch_xla.debug.metrics as met
self.print_fn(met.metrics_report())
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
-return "N/A"
+if trace_memory:
return "N/A", None
else:
return "N/A"
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = None
if self.args.torchscript:
config.torchscript = True
if self.args.with_lm_head:
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
else:
model = MODEL_MAPPING[config.__class__](config)
model.to(self.args.device)
model.eval()
model.to(self.args.device)
# encoder-decoder has vocab size saved differently
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
@@ -149,11 +180,22 @@ class PyTorchBenchmark(Benchmark):
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if self.args.torchscript:
with torch.no_grad():
if config.is_encoder_decoder:
raise NotImplementedError("Torchscript is currently not supported for EncoderDecoder models")
else:
inference_model = torch.jit.trace(model, input_ids)
else:
inference_model = model
def encoder_decoder_forward():
-model(input_ids, decoder_input_ids=input_ids)
+with torch.no_grad():
inference_model(input_ids, decoder_input_ids=input_ids)
def encoder_forward():
-model(input_ids)
+with torch.no_grad():
inference_model(input_ids)
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
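As a side note, here is a small self-contained sketch of the tracing pattern used above, with a toy module standing in for a transformers model: torch.jit.trace records the operations executed on the example input and returns a traced module that is then called in place of the original. Tracing follows a single execution path, which is presumably why the encoder-decoder case is rejected above.

# Illustrative torch.jit.trace sketch, not part of the diff.
import torch

class TinyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(100, 16)
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, input_ids):
        return self.proj(self.embed(input_ids))

model = TinyEncoder().eval()
example_ids = torch.randint(100, (1, 8), dtype=torch.long)
with torch.no_grad():
    traced_model = torch.jit.trace(model, example_ids)  # record ops on the example input
    output = traced_model(example_ids)                  # call the traced module like the original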
@@ -162,6 +204,7 @@ class PyTorchBenchmark(Benchmark):
trace = start_memory_tracing("transformers")
if self.args.n_gpu > 0:
# gpu
# clear gpu cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
@@ -172,7 +215,17 @@ class PyTorchBenchmark(Benchmark):
)
torch.cuda.reset_max_memory_cached()
-_forward()
+# run forward
_forward()
elif not self.args.no_tpu and is_torch_tpu_available():
# tpu
raise NotImplementedError(
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
)
else:
# cpu
memory_bytes = measure_peak_memory_cpu(_forward)
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
if self.args.trace_memory_line_by_line:
summary = stop_memory_tracing(trace)
@@ -188,26 +241,30 @@ class PyTorchBenchmark(Benchmark):
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
memory = Memory(torch.cuda.max_memory_cached())
else:
# cpu
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
memory = "N/A"
else:
process = psutil.Process(os.getpid())
memory = Memory(process.memory_info().rss)
return memory, summary
else:
if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
# run additional 10 times to stabilize compilation for tpu and torchscript
logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
timeit.repeat(
_forward, repeat=1, number=5,
)
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,)
if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
import torch_xla.debug.metrics as met
self.print_fn(met.metrics_report())
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
-return "N/A"
+if trace_memory:
return "N/A", None
else:
return "N/A"
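To make the timing scheme above concrete: timeit.repeat returns one total per repeat, each covering number calls, so min(runtimes) / 10.0 is the best-case time per call, and for TPU and torchscript an extra warm-up repeat is run first so compilation/tracing cost stays out of the measurement. A hedged standalone sketch, with a cheap stand-in for the model call:

# Illustrative timing sketch mirroring the scheme above, not part of the diff.
import timeit

def _forward():
    sum(i * i for i in range(10_000))  # stand-in for a model forward pass

# warm-up, as done for TPU/torchscript: run a few calls so compilation or
# tracing overhead does not leak into the measurement
timeit.repeat(_forward, repeat=1, number=5)

# each entry in `runtimes` is the total time for 10 calls
runtimes = timeit.repeat(_forward, repeat=3, number=10)
time_per_call = min(runtimes) / 10.0  # take the min, as the timeit docs recommend
print(f"{time_per_call:.6f} s per call")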
@@ -18,25 +18,16 @@ import logging
from dataclasses import dataclass, field
from typing import Tuple
-from ..file_utils import cached_property, is_torch_available, torch_required
+from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
-try:
+if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@@ -45,7 +36,9 @@ logger = logging.getLogger(__name__)
class PyTorchBenchmarkArguments(BenchmarkArguments):
no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"})
torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
no_tpu: bool = field(default=False, metadata={"help": "Whether to run on available tpu devices"})
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
tpu_print_metrics: bool = field(default=False, metadata={"help": "Whether to print TPU debug metrics"})
@cached_property
@torch_required
@@ -54,7 +47,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
-elif is_tpu_available():
+elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
else:
......
@@ -14,12 +14,15 @@ import sys
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
-from typing import Iterable, List, NamedTuple, Optional, Union
+from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from signal import SIGKILL
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
-from ..file_utils import is_tf_available, is_torch_available
+from ..file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .benchmark_args_utils import BenchmarkArguments
@@ -128,6 +131,127 @@ class MemorySummary(NamedTuple):
MemoryTrace = List[UsedMemoryState]
def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int:
"""
measures the peak CPU memory consumption of a given `function`,
running the function for at least `interval` seconds
and at most 20 * `interval` seconds.
This function is heavily inspired by `memory_usage`
of the package `memory_profiler`: https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
Args:
- `function`: (`callable`): function() -> ...
function without any arguments for which to measure the peak memory consumption
- `interval`: (`float`)
interval in seconds at which the memory usage is sampled
Returns:
- `max_memory`: (`int`)
peak memory consumption in Bytes
"""
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install Psutil (pip install psutil) to use CPU memory tracing."
)
max_memory = "N/A"
else:
def _get_memory(process_id: int) -> int:
"""
measures current cpu memory usage of a given `process_id`
Args:
- `process_id`: (`int`)
process_id for which to measure memory
Returns
- `memory`: (`int`)
consumed memory in Bytes
"""
process = psutil.Process(process_id)
try:
meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
memory = getattr(process, meminfo_attr)()[0]
except psutil.AccessDenied:
raise ValueError("Error with Psutil.")
return memory
class MemoryMeasureProcess(Process):
"""
`MemoryMeasureProcess` inherits from `Process` and overrides
its `run()` method. Used to measure the memory usage of a process
"""
def __init__(self, process_id: int, child_connection: Connection, interval: float):
super().__init__()
self.process_id = process_id
self.interval = interval
self.connection = child_connection
self.num_measurements = 1
self.mem_usage = _get_memory(process_id)
def run(self):
self.connection.send(0)
stop = False
while True:
self.mem_usage = max(self.mem_usage, _get_memory(self.process_id))
self.num_measurements += 1
if stop:
break
stop = self.connection.poll(self.interval)
# send results to parent pipe
self.connection.send(self.mem_usage)
self.connection.send(self.num_measurements)
while True:
# create child, parent connection
child_connection, parent_connection = Pipe()
# instantiate process
mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
mem_process.start()
# wait until we get memory
parent_connection.recv()
try:
# execute function
function()
# start parent connection
parent_connection.send(0)
# receive memory and num measurements
max_memory = parent_connection.recv()
num_measurements = parent_connection.recv()
except Exception:
# kill process in a clean way
parent = psutil.Process(os.getpid())
for child in parent.children(recursive=True):
os.kill(child.pid, SIGKILL)
mem_process.join(0)
raise RuntimeError("Process killed. Error in Process")
# run process at least 20 * interval or until it finishes
mem_process.join(20 * interval)
if (num_measurements > 4) or (interval < 1e-6):
break
# reduce interval
interval /= 10
return max_memory
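A hedged usage sketch of the helper defined above, assuming the import path transformers.benchmark.benchmark_utils used elsewhere in this diff: the helper returns a peak in bytes, or the string "N/A" when psutil is not installed, which is why the benchmark wraps the result in Memory only for integer values.

# Illustrative use of measure_peak_memory_cpu, not part of the diff; import path assumed.
from transformers.benchmark.benchmark_utils import Memory, measure_peak_memory_cpu

def allocate():
    _ = [0] * (10 ** 7)  # briefly hold a list of ~10 million references (~80 MB on 64-bit)

peak_bytes = measure_peak_memory_cpu(allocate, interval=0.5)
memory = Memory(peak_bytes) if isinstance(peak_bytes, int) else peak_bytes
print(memory)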
def start_memory_tracing(
modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
@@ -424,6 +548,10 @@ class Benchmark(ABC):
def is_gpu(self):
return self.args.n_gpu > 0
@property
def is_tpu(self):
return is_torch_tpu_available() and not self.args.no_tpu
@property
@abstractmethod
def framework_version(self):
@@ -486,6 +614,10 @@
self.print_fn("======= INFERENCE - SPEED - RESULT =======")
self.print_results(inference_result_time)
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if self.is_tpu:
self.print_fn(
"TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured."
)
if not self.args.no_memory:
self.print_fn("======= INFERENCE - MEMORY - RESULT =======")
@@ -501,6 +633,10 @@
self.print_fn("======= TRAIN - SPEED - RESULT =======")
self.print_results(train_result_time)
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if self.is_tpu:
self.print_fn(
"TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured."
)
if not self.args.no_memory:
self.print_fn("======= TRAIN - MEMORY - RESULT =======")
@@ -538,6 +674,8 @@
info = {}
info["transformers_version"] = version
info["framework"] = self.framework
if self.framework == "PyTorch":
info["use_torchscript"] = self.args.torchscript
info["framework_version"] = self.framework_version
info["python_version"] = platform.python_version()
info["system"] = platform.system()
@@ -590,6 +728,10 @@
info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle)
py3nvml.nvmlShutdown()
info["use_tpu"] = self.is_tpu
# TODO(PVP): See if we can add more information about TPU
# see: https://github.com/pytorch/xla/issues/2180
self._environment_info = info
return self._environment_info
......
@@ -68,6 +68,21 @@ except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm
tpu_device = xm.xla_device()
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@@ -98,6 +113,10 @@ def is_tf_available():
return _tf_available
def is_torch_tpu_available():
return _torch_tpu_available
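The new helper is meant as a guard around torch_xla imports and device selection, mirroring the call sites changed elsewhere in this commit (benchmark_args.py, training_args.py, trainer.py). A short sketch of that pattern, with the import path assumed from this diff:

# Illustrative guard pattern, not part of the diff.
from transformers.file_utils import is_torch_available, is_torch_tpu_available

if is_torch_available():
    import torch

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm

def pick_device():
    if is_torch_tpu_available():
        return xm.xla_device()  # TPU device handle
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")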
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
......
@@ -23,7 +23,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
-from .training_args import TrainingArguments, is_tpu_available
+from .training_args import TrainingArguments, is_torch_tpu_available
try:
@@ -38,7 +38,7 @@ def is_apex_available():
return _has_apex
-if is_tpu_available():
+if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
@@ -218,7 +218,7 @@ class Trainer:
# Create output directory if needed
if self.is_world_master():
os.makedirs(self.args.output_dir, exist_ok=True)
-if is_tpu_available():
+if is_torch_tpu_available():
# Set an xla_device flag on the model's config.
# We'll find a more elegant and not need to do this in the future.
self.model.config.xla_device = True
@@ -226,7 +226,7 @@
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
-if is_tpu_available():
+if is_torch_tpu_available():
train_sampler = get_tpu_sampler(self.train_dataset)
else:
train_sampler = (
@@ -251,7 +251,7 @@ class Trainer:
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-if is_tpu_available():
+if is_torch_tpu_available():
sampler = SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
@@ -272,7 +272,7 @@ class Trainer:
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
# We use the same batch_size as for eval.
-if is_tpu_available():
+if is_torch_tpu_available():
sampler = SequentialDistributedSampler(
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
@@ -407,7 +407,7 @@ class Trainer:
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
# Train!
-if is_tpu_available():
+if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
else:
total_train_batch_size = (
@@ -455,7 +455,7 @@ class Trainer:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
-if is_tpu_available():
+if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
@@ -482,7 +482,7 @@ class Trainer:
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-if is_tpu_available():
+if is_torch_tpu_available():
xm.optimizer_step(optimizer)
else:
optimizer.step()
@@ -525,7 +525,7 @@ class Trainer:
if self.is_world_master():
self._rotate_checkpoints()
-if is_tpu_available():
+if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -588,7 +588,7 @@ class Trainer:
return loss.item()
def is_local_master(self) -> bool:
-if is_tpu_available():
+if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
else:
return self.args.local_rank in [-1, 0]
@@ -598,7 +598,7 @@ class Trainer:
This will be True only in one process, even in distributed mode,
even when training on multiple machines.
"""
-if is_tpu_available():
+if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
else:
return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
@@ -611,7 +611,7 @@ class Trainer:
Will only save from the world_master process (unless in TPUs).
"""
-if is_tpu_available():
+if is_torch_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_master():
self._save(output_dir)
@@ -746,7 +746,7 @@ class Trainer:
label_ids: torch.Tensor = None
model.eval()
-if is_tpu_available():
+if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
for inputs in tqdm(dataloader, desc=description):
@@ -780,7 +780,7 @@ class Trainer:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
-elif is_tpu_available():
+elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
......
@@ -5,25 +5,15 @@ import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
-from .file_utils import cached_property, is_torch_available, torch_required
+from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
if is_torch_available():
import torch
if is_torch_tpu_available():
try:
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@@ -176,7 +166,7 @@ class TrainingArguments:
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
-elif is_tpu_available():
+elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
elif self.local_rank == -1:
......
@@ -33,6 +33,21 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_torchscript(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
torchscript=True,
sequence_lengths=[8],
batch_sizes=[1],
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
@@ -76,6 +91,22 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_train_with_configs_torchscript(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=True,
torchscript=True,
sequence_lengths=[8],
batch_sizes=[1],
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_train_encoder_decoder_with_configs(self):
MODEL_ID = "sshleifer/tinier_bart"
config = AutoConfig.from_pretrained(MODEL_ID)
......