"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "91758e399f8c4bf81820a8af6a257682ccea0223"
Unverified commit 6e058e84, authored by Yanming Wang, committed by GitHub

Enable AMP for xla:gpu device in trainer class (#15022)

* Multiple fixes to the trainer class with XLA GPU

* Make fp16 valid for xla:gpu

* Add mark_step in should_log to reduce compilation overhead
parent 3fc221d0
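
For orientation before the diff: a minimal sketch (not code from this commit) of the AMP training step these hunks wire into Trainer, assuming torch_xla is installed and the process targets an xla:gpu device. model, loader, and optimizer are hypothetical placeholders, and the use of torch.cuda.amp.autocast for the forward pass is an assumption, not something this diff shows.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler  # XLA-aware counterpart of torch.cuda.amp.GradScaler

device = xm.xla_device()  # resolves to a GPU-backed XLA device when GPU_NUM_DEVICES is set
model = model.to(device)  # `model`, `loader`, `optimizer` assumed defined elsewhere
scaler = GradScaler()

for batch in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # fp16 forward pass (assumed context manager)
        loss = model(**batch).loss
    scaler.scale(loss).backward()

    # scaler.step() does not all-reduce across replicas the way xm.optimizer_step()
    # does, so gradients are reduced by hand first (see the second hunk below).
    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())

    scaler.unscale_(optimizer)  # unscale before any gradient clipping
    scaler.step(optimizer)      # replaces xm.optimizer_step under grad scaling (third hunk)
    scaler.update()
    xm.mark_step()              # cut the lazy trace so compilations stay bounded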
src/transformers/trainer.py
@@ -449,6 +449,10 @@ class Trainer:
                     self.scaler = smp.amp.GradScaler()
                 elif self.sharded_ddp is not None:
                     self.scaler = ShardedGradScaler()
+                elif is_torch_tpu_available():
+                    from torch_xla.amp import GradScaler
+
+                    self.scaler = GradScaler()
                 else:
                     self.scaler = torch.cuda.amp.GradScaler()
             else:
@@ -1386,6 +1390,10 @@ class Trainer:
                     # deepspeed does its own clipping
                     if self.do_grad_scaling:
+                        # Reduce gradients first for XLA
+                        if is_torch_tpu_available():
+                            gradients = xm._fetch_gradients(self.optimizer)
+                            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                         # AMP: gradients need unscaling
                         self.scaler.unscale_(self.optimizer)
@@ -1407,7 +1415,11 @@ class Trainer:
                 if self.deepspeed:
                     pass  # called outside the loop
                 elif is_torch_tpu_available():
-                    xm.optimizer_step(self.optimizer)
+                    if self.do_grad_scaling:
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        xm.optimizer_step(self.optimizer)
                 elif self.do_grad_scaling:
                     scale_before = self.scaler.get_scale()
                     self.scaler.step(self.optimizer)
@@ -1528,6 +1540,9 @@ class Trainer:
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log:
+            if is_torch_tpu_available():
+                xm.mark_step()
+
             logs: Dict[str, float] = {}

             # all_gather + mean() to get average loss over all processes
@@ -2362,6 +2377,9 @@ class Trainer:
             # Prediction step
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

+            if is_torch_tpu_available():
+                xm.mark_step()
+
             # Update containers on host
             if loss is not None:
                 losses = self._nested_gather(loss.repeat(batch_size))
src/transformers/training_args.py
@@ -838,7 +838,8 @@ class TrainingArguments:
         if (
             is_torch_available()
-            and self.device.type != "cuda"
+            and (self.device.type != "cuda")
+            and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
             and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
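
For context, a hedged illustration of what the relaxed check permits; it assumes torch_xla is installed so that TrainingArguments resolves an xla device, and the output_dir value is illustrative.

import os

os.environ["GPU_NUM_DEVICES"] = "1"  # select the XLA GPU backend before torch_xla initializes

from transformers import TrainingArguments

# With device.type == "xla" and GPU_NUM_DEVICES set, validation no longer
# raises the "mixed precision can only be used on CUDA devices" ValueError,
# so fp16 training can proceed on an XLA-backed GPU.
args = TrainingArguments(output_dir="out", fp16=True)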