Unverified Commit 702a76ff authored by Julien Plu, committed by GitHub

Create an XLA parameter and fix the mixed precision (#7311)

* Create an XLA parameter and fix mixed precision creation

* Fix an issue introduced by IntelliSense

* Complete docstring
parent 596342c2
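
In practice the two options combine as below; a minimal usage sketch, where the output directory is a placeholder and only `xla` and `fp16` are the options this commit touches:

```python
from transformers import TFTrainingArguments

training_args = TFTrainingArguments(
    output_dir="./output",  # placeholder path
    fp16=True,              # mixed precision: float16 on GPU, bfloat16 on TPU
    xla=True,               # new flag: enables XLA compilation during strategy setup
)
```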
@@ -531,10 +531,6 @@ class TFTrainer:
         tf.summary.experimental.set_step(self.global_step)

-        if self.args.fp16:
-            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
-            tf.keras.mixed_precision.experimental.set_policy(policy)
-
         with self.tb_writer.as_default():
             tf.summary.text("args", self.args.to_json_string())
...
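The deletion above moves mixed-precision setup out of the TensorBoard logging path: the global Keras policy has to be set before any model layers are built, so it now lives in `_setup_strategy` (next hunks), where the hardware is also known. A minimal sketch of the TF 2.3-era experimental API's effect, assuming a CPU or GPU runtime:

```python
import tensorflow as tf

# Set the global mixed-precision policy before building any layers.
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)

# Layers created afterwards compute in float16 but keep float32 variables.
dense = tf.keras.layers.Dense(4)
print(dense(tf.ones([1, 4])).dtype)  # float16
print(dense.kernel.dtype)            # float32
```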
@@ -88,7 +88,7 @@ class TFTrainingArguments(TrainingArguments):
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
+            Whether to activate the trace to record computation graphs and profiling information or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
@@ -103,6 +103,8 @@ class TFTrainingArguments(TrainingArguments):
             The name of the TPU the process is running on.
         run_name (:obj:`str`, `optional`):
             A descriptor for the run. Notably used for wandb logging.
+        xla (:obj:`bool`, `optional`):
+            Whether to activate the XLA compilation or not.
     """

     tpu_name: str = field(
@@ -110,12 +112,23 @@ class TFTrainingArguments(TrainingArguments):
         metadata={"help": "Name of TPU"},
     )

+    xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
+
     @cached_property
     @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
         logger.info("Tensorflow: setting up strategy")
+
+        if self.xla:
+            tf.config.optimizer.set_jit(True)
+
         gpus = tf.config.list_physical_devices("GPU")
+
+        # Default the policy to float16; overridden to bfloat16 below when a TPU is available
+        if self.fp16:
+            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+            tf.keras.mixed_precision.experimental.set_policy(policy)
+
         if self.no_cuda:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         else:
...
@@ -128,10 +141,16 @@ class TFTrainingArguments(TrainingArguments):
                 tpu = None

         if tpu:
+            # On TPU, use bfloat16 mixed precision instead of float16
+            if self.fp16:
+                policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
+                tf.keras.mixed_precision.experimental.set_policy(policy)
+
             tf.config.experimental_connect_to_cluster(tpu)
             tf.tpu.experimental.initialize_tpu_system(tpu)
             strategy = tf.distribute.experimental.TPUStrategy(tpu)
         elif len(gpus) == 0:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         elif len(gpus) == 1:
...
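
For reference, what the new `xla` flag does under the hood, plus the per-function opt-in TensorFlow also offers; a sketch assuming a TF 2.3-era install (the `experimental_compile` spelling was later renamed `jit_compile`, and this comparison is not part of the commit):

```python
import tensorflow as tf

# Global XLA JIT, as enabled by `xla=True` in this commit.
tf.config.optimizer.set_jit(True)

# Comparable per-function opt-in for XLA compilation.
@tf.function(experimental_compile=True)
def scaled_add(x, y):
    return 2.0 * x + y

print(scaled_add(tf.ones([2, 2]), tf.ones([2, 2])))
```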