"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6dc4b6f34c26840b82200c1951d176698c55bb0f"
Unverified Commit 702a76ff authored by Julien Plu, committed by GitHub

Create an XLA parameter and fix the mixed precision (#7311)

* Create an XLA parameter and fix mixed precision creation

* Fix an issue introduced by IntelliSense

* Complete docstring
parent 596342c2
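
The hunks below remove the mixed-precision setup from `TFTrainer.train` and move it, together with the new `xla` flag, into `TFTrainingArguments._setup_strategy`. A minimal usage sketch of the new argument, assuming the standard `TFTrainingArguments` constructor (only `fp16` and `xla` come from this change; the `output_dir` value is illustrative):

```python
from transformers import TFTrainingArguments

# Sketch only: output_dir is an illustrative value, not part of this change.
args = TFTrainingArguments(
    output_dir="./out",
    fp16=True,  # mixed_float16 on GPU, mixed_bfloat16 on TPU (set in _setup_strategy)
    xla=True,   # new flag: enables XLA JIT compilation via tf.config.optimizer.set_jit(True)
)

# Both settings are applied the first time the distribution strategy is set up.
```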
@@ -531,10 +531,6 @@ class TFTrainer:
         tf.summary.experimental.set_step(self.global_step)

-        if self.args.fp16:
-            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
-            tf.keras.mixed_precision.experimental.set_policy(policy)

         with self.tb_writer.as_default():
             tf.summary.text("args", self.args.to_json_string())
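
The removed block is not dropped: as the hunks below show, the same policy setup now runs in `TFTrainingArguments._setup_strategy`, i.e. before the model is built. A standalone sketch of the underlying TF 2.x pattern (experimental mixed-precision API, as used in this diff; not Trainer code):

```python
import tensorflow as tf

# The global Keras policy must be set before the model's layers are created
# so that they pick up the mixed-precision compute dtype.
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
print(tf.keras.mixed_precision.experimental.global_policy())  # mixed_float16
```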
@@ -88,7 +88,7 @@ class TFTrainingArguments(TrainingArguments):
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
+            Whether to activate the trace to record computation graphs and profiling information or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
@@ -103,6 +103,8 @@ class TFTrainingArguments(TrainingArguments):
             The name of the TPU the process is running on.
         run_name (:obj:`str`, `optional`):
             A descriptor for the run. Notably used for wandb logging.
+        xla (:obj:`bool`, `optional`):
+            Whether to activate the XLA compilation or not.
     """

     tpu_name: str = field(
@@ -110,12 +112,23 @@ class TFTrainingArguments(TrainingArguments):
         metadata={"help": "Name of TPU"},
     )

+    xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})

     @cached_property
     @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
         logger.info("Tensorflow: setting up strategy")

+        if self.args.xla:
+            tf.config.optimizer.set_jit(True)

         gpus = tf.config.list_physical_devices("GPU")

+        # Set to float16 at first
+        if self.fp16:
+            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+            tf.keras.mixed_precision.experimental.set_policy(policy)

         if self.no_cuda:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         else:
@@ -128,10 +141,16 @@ class TFTrainingArguments(TrainingArguments):
             tpu = None

         if tpu:
+            # Set to bfloat16 in case of TPU
+            if self.fp16:
+                policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
+                tf.keras.mixed_precision.experimental.set_policy(policy)

             tf.config.experimental_connect_to_cluster(tpu)
             tf.tpu.experimental.initialize_tpu_system(tpu)
             strategy = tf.distribute.experimental.TPUStrategy(tpu)
         elif len(gpus) == 0:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         elif len(gpus) == 1:
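
Condensed, the setup that `_setup_strategy` now performs looks roughly like the sketch below (assuming the TF 2.x experimental mixed-precision API used in the diff; the helper function and its arguments are illustrative, not part of the change):

```python
import tensorflow as tf

def configure(xla: bool, fp16: bool, on_tpu: bool) -> None:
    """Illustrative helper mirroring the logic added in this commit."""
    if xla:
        # Enable XLA JIT compilation globally.
        tf.config.optimizer.set_jit(True)
    if fp16:
        # GPUs get float16, TPUs get bfloat16 as the mixed-precision policy.
        name = "mixed_bfloat16" if on_tpu else "mixed_float16"
        policy = tf.keras.mixed_precision.experimental.Policy(name)
        tf.keras.mixed_precision.experimental.set_policy(policy)

configure(xla=True, fp16=True, on_tpu=False)
```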