Unverified Commit 31f6758a authored by jihan.yang's avatar jihan.yang Committed by GitHub
Browse files

Fix bug: torch.cuda.amp rather than torch.amp.cuda in loss (#1252)

* Fix bug: torch.cuda.amp rather than torch.amp.cuda
parent 633db6be
...@@ -148,7 +148,7 @@ class WeightedL1Loss(nn.Module): ...@@ -148,7 +148,7 @@ class WeightedL1Loss(nn.Module):
self.code_weights = np.array(code_weights, dtype=np.float32) self.code_weights = np.array(code_weights, dtype=np.float32)
self.code_weights = torch.from_numpy(self.code_weights).cuda() self.code_weights = torch.from_numpy(self.code_weights).cuda()
@torch.amp.cuda.custom_fwd(cast_inputs=torch.float16) @torch.cuda.amp.custom_fwd(cast_inputs=torch.float16)
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None): def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None):
""" """
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment