Unverified Commit 752de3e4 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

change Adam to AdamW (#2412)

parent 39c2c0a7
...@@ -75,7 +75,7 @@ class HuBERTPreTrainModule(LightningModule): ...@@ -75,7 +75,7 @@ class HuBERTPreTrainModule(LightningModule):
raise ValueError(f"Unsupported model name: {model_name}") raise ValueError(f"Unsupported model name: {model_name}")
self.loss = hubert_loss self.loss = hubert_loss
self.optimizer = torch.optim.Adam( self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
) )
self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates) self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment