"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7291ea0bff57a017e71b1ea8ec01ff19da298bf0"
Unverified Commit f82a2a5e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Benchmark] Add benchmarks for TF Training (#5594)

* tf_train

* adapt timing for tpu

* fix timing

* fix timing

* fix timing

* fix timing

* update notebook

* add tests
parent cfbb9829
...@@ -312,8 +312,8 @@ ...@@ -312,8 +312,8 @@
":-- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n", ":-- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n",
"**Speed - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✔ |\n", "**Speed - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✔ |\n",
"**Memory - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✘ |\n", "**Memory - Inference** | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✘ |\n",
"**Speed - Train** | | ✘ | ✘ | ✘ | ✘ | ✘ | ✘ |\n", "**Speed - Train** | ✔ | ✘ | ✔ | ✘ | ✘ | ✘ | |\n",
"**Memory - Train** | | ✘ | | ✘ | ✘ | ✘ | ✘ |\n", "**Memory - Train** | | ✘ | | ✘ | ✘ | ✘ | ✘ |\n",
"\n", "\n",
"* *eager execution* means that the function is run in the eager execution environment of TensorFlow 2, see [here](https://www.tensorflow.org/guide/eager).\n", "* *eager execution* means that the function is run in the eager execution environment of TensorFlow 2, see [here](https://www.tensorflow.org/guide/eager).\n",
"\n", "\n",
...@@ -321,7 +321,7 @@ ...@@ -321,7 +321,7 @@
"\n", "\n",
"* *FP16* stands for TensorFlow's mixed-precision package and is analogous to PyTorch's FP16 feature, see [here](https://www.tensorflow.org/guide/mixed_precision).\n", "* *FP16* stands for TensorFlow's mixed-precision package and is analogous to PyTorch's FP16 feature, see [here](https://www.tensorflow.org/guide/mixed_precision).\n",
"\n", "\n",
"***Note***: In ~1,2 weeks it will also be possible to benchmark training in TensorFlow.\n", "***Note***: Benchmark training in TensorFlow is not included in v3.0.2, but available in master.\n",
"\n", "\n",
"\n", "\n",
"This notebook will show the user how to use `PyTorchBenchmark` and `TensorFlowBenchmark` for two different scenarios:\n", "This notebook will show the user how to use `PyTorchBenchmark` and `TensorFlowBenchmark` for two different scenarios:\n",
......
...@@ -157,7 +157,7 @@ class PyTorchBenchmark(Benchmark): ...@@ -157,7 +157,7 @@ class PyTorchBenchmark(Benchmark):
else: else:
train_model = model train_model = model
model.eval() model.train()
model.to(self.args.device) model.to(self.args.device)
# encoder-decoder has vocab size saved differently # encoder-decoder has vocab size saved differently
...@@ -175,12 +175,12 @@ class PyTorchBenchmark(Benchmark): ...@@ -175,12 +175,12 @@ class PyTorchBenchmark(Benchmark):
def compute_loss_and_backprob_encoder(): def compute_loss_and_backprob_encoder():
loss = train_model(input_ids, labels=input_ids)[0] loss = train_model(input_ids, labels=input_ids)[0]
loss.backward() loss.backward()
train_model.zero_grad() return loss
def compute_loss_and_backprob_encoder_decoder(): def compute_loss_and_backprob_encoder_decoder():
loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0] loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
loss.backward() loss.backward()
train_model.zero_grad() return loss
_train = ( _train = (
compute_loss_and_backprob_encoder_decoder compute_loss_and_backprob_encoder_decoder
......
...@@ -24,7 +24,13 @@ import timeit ...@@ -24,7 +24,13 @@ import timeit
from functools import wraps from functools import wraps
from typing import Callable, Optional from typing import Callable, Optional
from transformers import TF_MODEL_MAPPING, PretrainedConfig, is_py3nvml_available, is_tf_available from transformers import (
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
PretrainedConfig,
is_py3nvml_available,
is_tf_available,
)
from .benchmark_utils import ( from .benchmark_utils import (
Benchmark, Benchmark,
...@@ -92,10 +98,11 @@ class TensorFlowBenchmark(Benchmark): ...@@ -92,10 +98,11 @@ class TensorFlowBenchmark(Benchmark):
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length) _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_speed(_inference) return self._measure_speed(_inference)
def _train_speed(self, model_name, batch_size, sequence_length): def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
raise NotImplementedError( strategy = self.args.strategy
"Training is currently not really implemented." "Wait for TFTrainer to support CLM and MLM." assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
) _train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_speed(_train)
def _inference_memory( def _inference_memory(
self, model_name: str, batch_size: int, sequence_length: int self, model_name: str, batch_size: int, sequence_length: int
...@@ -108,10 +115,16 @@ class TensorFlowBenchmark(Benchmark): ...@@ -108,10 +115,16 @@ class TensorFlowBenchmark(Benchmark):
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length) _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_memory(_inference) return self._measure_memory(_inference)
def _train_memory(self, model_name, batch_size, sequence_length): def _train_memory(
raise NotImplementedError( self, model_name: str, batch_size: int, sequence_length: int
"Training is currently not really implemented. Wait for TFTrainer to support CLM and MLM." ) -> [Memory, Optional[MemorySummary]]:
) if self.args.is_gpu:
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
strategy = self.args.strategy
assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_memory(_train)
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]: def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name] config = self.config_dict[model_name]
...@@ -149,6 +162,50 @@ class TensorFlowBenchmark(Benchmark): ...@@ -149,6 +162,50 @@ class TensorFlowBenchmark(Benchmark):
return _inference return _inference
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
assert (
self.args.eager_mode is False
), "Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`."
if self.args.fp16:
raise NotImplementedError("Mixed precision is currently not supported.")
has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)
except ImportError:
raise ImportError(
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
# encoder-decoder has vocab size saved differently
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_decoder_train():
loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
gradients = tf.gradients(loss, model.trainable_variables)
return gradients
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_train():
loss = model(input_ids, labels=input_ids, training=True)[0]
gradients = tf.gradients(loss, model.trainable_variables)
return gradients
_train = encoder_decoder_train if config.is_encoder_decoder else encoder_train
return _train
def _measure_speed(self, func) -> float: def _measure_speed(self, func) -> float:
with self.args.strategy.scope(): with self.args.strategy.scope():
try: try:
......
...@@ -100,6 +100,37 @@ class TFBenchmarkTest(unittest.TestCase): ...@@ -100,6 +100,37 @@ class TFBenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.time_inference_result) self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result) self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=True,
sequence_lengths=[8],
batch_sizes=[1],
no_multi_process=True,
)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_train_with_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=True,
sequence_lengths=[8],
batch_sizes=[1],
no_multi_process=True,
)
benchmark = TensorFlowBenchmark(benchmark_args, [config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_inference_encoder_decoder_with_configs(self): def test_inference_encoder_decoder_with_configs(self):
MODEL_ID = "patrickvonplaten/t5-tiny-random" MODEL_ID = "patrickvonplaten/t5-tiny-random"
config = AutoConfig.from_pretrained(MODEL_ID) config = AutoConfig.from_pretrained(MODEL_ID)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment