"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a0939977a3b3c34c925c565c3fd3dcbe5d09e23c"
Commit 030646c0 authored by Zhaoheng Ni

Enable mixed precision training for hubert_pretrain_model (#2854)

Summary:
Addresses https://github.com/pytorch/audio/issues/2847

In mixed precision training, the dtype of `mask_embedding` is **not** converted to fp16 automatically. This PR addresses the issue by casting `mask_embedding` to the dtype of `x`, enabling mixed precision training.
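For illustration, here is a minimal, self-contained sketch of the failure and of the fix applied in this PR; the tensor shapes and the standalone `mask_embedding` parameter are placeholders, not the model's actual dimensions:

```python
import torch

# mask_embedding is registered as a float32 Parameter in the model.
mask_embedding = torch.nn.Parameter(torch.zeros(768))

# Under autocast, the activations x that reach MaskGenerator are float16.
x = torch.randn(4, 100, 768, dtype=torch.float16)
mask_indices = torch.zeros(4, 100, dtype=torch.bool)
mask_indices[:, :10] = True

# Before this PR: index_put_ refuses to write a float32 source into a float16 tensor.
# x[mask_indices] = mask_embedding  # RuntimeError: source and destination dtypes must match

# After this PR: cast the embedding to x's dtype before the masked assignment.
x[mask_indices] = mask_embedding.to(x.dtype)
```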

Pull Request resolved: https://github.com/pytorch/audio/pull/2854

Reviewed By: carolineechen

Differential Revision: D41343486

Pulled By: nateanl

fbshipit-source-id: 4a5cbb429ff8ba5d3c439a3d5acb5094f66bf705
parent 980528e9
@@ -247,7 +247,7 @@ class HuBERTPreTrainModule(LightningModule):
         """
         opt = self.optimizers()
         opt.zero_grad()
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.cuda.amp.autocast(enabled=True):
             loss, num_frame = self._step(batch, batch_idx, "train")
         if torch.isinf(loss) or torch.isnan(loss):
             opt.zero_grad()
@@ -480,7 +480,7 @@ class HuBERTFineTuneModule(LightningModule):
         """
         opt = self.optimizers()
         opt.zero_grad()
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.cuda.amp.autocast(enabled=True):
             loss = self._step(batch, batch_idx, "train")
         # normalize the loss based on the sum of batch_size across all GPUs
...
@@ -964,7 +964,9 @@ class MaskGenerator(Module):
                 min_space=self.mask_min_space,
             )
             mask_indices = mask_indices.to(x.device)
-            x[mask_indices] = self.mask_embedding
+            # change dtype of mask_embedding to x for mixed-precision training.
+            # see https://github.com/pytorch/audio/issues/2847 for details.
+            x[mask_indices] = self.mask_embedding.to(x.dtype)
         else:
             mask_indices = None
...
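For context, enabling `autocast` is usually paired with gradient scaling so that fp16 gradients do not underflow. The following is a generic PyTorch AMP training-step sketch (placeholder model, loss, and step function, not the repository's actual Lightning code):

```python
import torch

model = torch.nn.Linear(128, 10).cuda()           # placeholder model
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.cuda.amp.GradScaler()

def train_step(batch, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        output = model(batch)                     # forward ops run in reduced precision where safe
        loss = torch.nn.functional.cross_entropy(output, target)
    scaler.scale(loss).backward()                 # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                        # unscales gradients, skips the step on inf/nan
    scaler.update()
    return loss.detach()
```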