Unverified Commit 6e058e84 authored by Yanming Wang, committed by GitHub

Enable AMP for xla:gpu device in trainer class (#15022)

* Multiple fixes of trainer class with XLA GPU

* Make fp16 valid for xla:gpu

* Add mark_step in should_log to reduce compilation overhead
parent 3fc221d0
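For orientation, here is a minimal, hypothetical sketch of what this change enables from the user side: running the Trainer in fp16 on an XLA GPU device. It assumes a torch_xla build with GPU support; the output directory and batch size are placeholders, and GPU_NUM_DEVICES is the environment variable checked by the new TrainingArguments guard in the last hunk below.

```python
# Sketch only (not part of the commit): fp16 training on xla:gpu.
import os

# The relaxed TrainingArguments check accepts fp16 on an "xla" device only
# when GPU_NUM_DEVICES is set, so point XLA at GPUs before building the args.
os.environ.setdefault("GPU_NUM_DEVICES", "1")

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                # placeholder path
    fp16=True,                       # no longer rejected when xla is GPU-backed
    per_device_train_batch_size=8,   # placeholder value
)
# Trainer(model=..., args=args, train_dataset=...) would then pick the
# torch_xla.amp.GradScaler added in this commit instead of torch.cuda.amp's.
```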
@@ -449,6 +449,10 @@ class Trainer:
                     self.scaler = smp.amp.GradScaler()
                 elif self.sharded_ddp is not None:
                     self.scaler = ShardedGradScaler()
+                elif is_torch_tpu_available():
+                    from torch_xla.amp import GradScaler
+
+                    self.scaler = GradScaler()
                 else:
                     self.scaler = torch.cuda.amp.GradScaler()
             else:
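Taken on its own, the new branch prefers the XLA-aware scaler whenever torch_xla is importable, and falls back to the native CUDA one otherwise. A standalone sketch of that selection (the helper name make_grad_scaler is illustrative, not part of the library):

```python
import torch


def make_grad_scaler():
    """Pick a gradient scaler matching the runtime, mirroring the branch above."""
    try:
        # Importable only when torch_xla is installed (the is_torch_tpu_available() case).
        from torch_xla.amp import GradScaler as XlaGradScaler

        return XlaGradScaler()
    except ImportError:
        # Plain CUDA AMP scaler for native GPU training.
        return torch.cuda.amp.GradScaler()
```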
@@ -1386,6 +1390,10 @@ class Trainer:
                         # deepspeed does its own clipping
                         if self.do_grad_scaling:
+                            # Reduce gradients first for XLA
+                            if is_torch_tpu_available():
+                                gradients = xm._fetch_gradients(self.optimizer)
+                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                             # AMP: gradients need unscaling
                             self.scaler.unscale_(self.optimizer)
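Because the scaled path no longer goes through xm.optimizer_step(), which normally averages gradients across replicas, that reduction has to happen explicitly before unscaling. A sketch of the step in isolation; xm._fetch_gradients is a private torch_xla helper, used here only because the commit itself relies on it:

```python
import torch_xla.core.xla_model as xm


def all_reduce_gradients(optimizer):
    """Average gradients across XLA replicas before unscaling/stepping.

    xm.optimizer_step() does this internally; when the optimizer is driven
    through GradScaler.step() instead, it has to be done by hand.
    """
    gradients = xm._fetch_gradients(optimizer)  # private helper, as in the diff
    # In-place cross-replica sum, scaled to a mean over the world size.
    xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
```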
@@ -1407,7 +1415,11 @@ class Trainer:
                     if self.deepspeed:
                         pass  # called outside the loop
                     elif is_torch_tpu_available():
-                        xm.optimizer_step(self.optimizer)
+                        if self.do_grad_scaling:
+                            self.scaler.step(self.optimizer)
+                            self.scaler.update()
+                        else:
+                            xm.optimizer_step(self.optimizer)
                     elif self.do_grad_scaling:
                         scale_before = self.scaler.get_scale()
                         self.scaler.step(self.optimizer)
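Combined with the previous hunk, the XLA optimizer step now has two paths. A condensed sketch (the function name is illustrative), assuming the scaler is the torch_xla.amp.GradScaler created earlier:

```python
import torch_xla.core.xla_model as xm


def xla_optimizer_step(optimizer, scaler=None):
    """Step an optimizer on an XLA device, with or without grad scaling."""
    if scaler is not None:
        # Gradients were already all-reduced and unscaled (previous hunk);
        # scaler.step() skips the update if any inf/nan was found.
        scaler.step(optimizer)
        scaler.update()
    else:
        # Full-precision path: xm.optimizer_step all-reduces the gradients
        # and steps the optimizer in one call.
        xm.optimizer_step(optimizer)
```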
@@ -1528,6 +1540,9 @@ class Trainer:
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log:
+            if is_torch_tpu_available():
+                xm.mark_step()
+
             logs: Dict[str, float] = {}

             # all_gather + mean() to get average loss over all processes
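xm.mark_step() cuts the current lazy-execution graph and dispatches it. Calling it before the logging block means the host-side reads that follow (gathering and averaging the loss) hit an already-dispatched graph instead of forcing an extra, differently shaped compilation, which is presumably the compilation overhead the commit message mentions. A toy sketch of the pattern, with the training step itself left as a placeholder:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
running_loss = torch.zeros((), device=device)

for step in range(100):
    # ... forward / backward / optimizer step would go here (placeholder) ...
    running_loss = running_loss + 1.0  # stands in for the per-step loss

    if (step + 1) % 10 == 0:
        # Flush pending XLA ops before pulling values back to the host,
        # so the logging read does not trigger a fresh compilation.
        xm.mark_step()
        print(f"mean loss so far: {running_loss.item() / (step + 1):.3f}")
```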
@@ -2362,6 +2377,9 @@ class Trainer:
             # Prediction step
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

+            if is_torch_tpu_available():
+                xm.mark_step()
+
             # Update containers on host
             if loss is not None:
                 losses = self._nested_gather(loss.repeat(batch_size))
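The same synchronization point is added to the evaluation loop: after each prediction step, mark_step() dispatches that batch's graph before results are moved to the host. A sketch of an evaluation loop written in that style; the model's output format and the dataloader are assumptions, not the Trainer's actual code:

```python
import torch
import torch_xla.core.xla_model as xm


def evaluate(model, dataloader):
    """Toy eval loop: dispatch each batch's XLA graph before host transfers."""
    device = xm.xla_device()
    model.eval()
    losses = []
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss   # assumes an HF-style model output
            xm.mark_step()               # cut and launch the graph for this batch
            losses.append(loss.item())   # the read now waits on a dispatched graph
    return sum(losses) / max(len(losses), 1)
```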
@@ -838,7 +838,8 @@ class TrainingArguments:
         if (
             is_torch_available()
-            and self.device.type != "cuda"
+            and (self.device.type != "cuda")
+            and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
             and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
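With the relaxed guard, an "xla" device passes the fp16/bf16 check only when XLA is actually targeting GPUs, which torch_xla signals through the GPU_NUM_DEVICES environment variable. A condensed sketch of the predicate (the function name is illustrative):

```python
import os

import torch


def mixed_precision_allowed(device: torch.device) -> bool:
    """Mirror the relaxed check: CUDA always allowed, XLA only when GPU-backed."""
    if device.type == "cuda":
        return True
    # torch_xla's GPU backend is selected via GPU_NUM_DEVICES.
    return device.type == "xla" and "GPU_NUM_DEVICES" in os.environ


os.environ.pop("GPU_NUM_DEVICES", None)
print(mixed_precision_allowed(torch.device("xla")))  # False: xla without GPUs
os.environ["GPU_NUM_DEVICES"] = "1"
print(mixed_precision_allowed(torch.device("xla")))  # True: xla:gpu
print(mixed_precision_allowed(torch.device("cpu")))  # False: fp16 still rejected on CPU
```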