Commit 030646c0 authored by Zhaoheng Ni

Enable mixed precision training for hubert_pretrain_model (#2854)

Summary:
Addresses https://github.com/pytorch/audio/issues/2847

In mixed precision training, the dtype of `mask_embedding` is **not** converted to fp16 automatically. This PR addresses the issue by casting `mask_embedding` to the dtype of `x`, enabling mixed precision training.
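For context, the sketch below reproduces the failure mode this PR fixes; it is illustrative only and not the torchaudio code. Under `torch.cuda.amp.autocast`, the output of a linear layer on CUDA is cast to fp16 while a registered parameter (here a hypothetical `mask_embedding`) stays fp32, so the indexed assignment fails with an index_put dtype-mismatch error unless the parameter is cast to `x.dtype`. The `Toy` module, shapes, and mask layout are made up for the example.

```python
# Minimal sketch of the dtype mismatch fixed here (illustrative; `Toy`,
# its shapes, and the mask layout are hypothetical, not the torchaudio model).
import torch


class Toy(torch.nn.Module):
    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.mask_embedding = torch.nn.Parameter(torch.rand(dim))

    def forward(self, x, mask):
        x = self.proj(x)  # under autocast on CUDA this output is fp16
        # Without .to(x.dtype), assigning the fp32 parameter into the fp16
        # tensor raises an index_put dtype-mismatch error.
        x[mask] = self.mask_embedding.to(x.dtype)
        return x


if torch.cuda.is_available():
    model = Toy().cuda()
    x = torch.randn(2, 3, 4, device="cuda")
    mask = torch.zeros(2, 3, dtype=torch.bool, device="cuda")
    mask[0, 1] = True
    with torch.cuda.amp.autocast(enabled=True):
        out = model(x, mask)
    print(out.dtype)  # torch.float16
```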

Pull Request resolved: https://github.com/pytorch/audio/pull/2854

Reviewed By: carolineechen

Differential Revision: D41343486

Pulled By: nateanl

fbshipit-source-id: 4a5cbb429ff8ba5d3c439a3d5acb5094f66bf705
parent 980528e9
@@ -247,7 +247,7 @@ class HuBERTPreTrainModule(LightningModule):
         """
         opt = self.optimizers()
         opt.zero_grad()
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.cuda.amp.autocast(enabled=True):
             loss, num_frame = self._step(batch, batch_idx, "train")
         if torch.isinf(loss) or torch.isnan(loss):
             opt.zero_grad()
@@ -480,7 +480,7 @@ class HuBERTFineTuneModule(LightningModule):
         """
         opt = self.optimizers()
         opt.zero_grad()
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.cuda.amp.autocast(enabled=True):
             loss = self._step(batch, batch_idx, "train")
         # normalize the loss based on the sum of batch_size across all GPUs
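For reference, switching these training steps to `autocast(enabled=True)` follows the usual mixed-precision pattern. Whether the recipe pairs autocast with a `GradScaler` is not visible in these hunks, so the standalone sketch below is illustrative only; the function and argument names are hypothetical.

```python
# Hedged sketch of a standalone autocast + GradScaler step (illustrative only;
# the Lightning modules above manage their optimizers manually and may differ).
import torch


def train_step(model, batch, target, optimizer, scaler, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        output = model(batch)           # forward pass runs in mixed precision
        loss = loss_fn(output, target)  # loss is computed under autocast
    scaler.scale(loss).backward()       # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)              # unscales gradients, then optimizer.step()
    scaler.update()                     # adjust the scale factor for the next step
    return loss.detach()
```

A scaler would typically be created once with `scaler = torch.cuda.amp.GradScaler()` and reused across steps.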
@@ -964,7 +964,9 @@ class MaskGenerator(Module):
                 min_space=self.mask_min_space,
             )
             mask_indices = mask_indices.to(x.device)
-            x[mask_indices] = self.mask_embedding
+            # cast mask_embedding to the dtype of x for mixed-precision training.
+            # see https://github.com/pytorch/audio/issues/2847 for details.
+            x[mask_indices] = self.mask_embedding.to(x.dtype)
         else:
             mask_indices = None