Unverified Commit 633db6be authored by Duo Li's avatar Duo Li Committed by GitHub
Browse files

Update loss_utils.py (#1247)

parent e948c537
...@@ -148,6 +148,7 @@ class WeightedL1Loss(nn.Module): ...@@ -148,6 +148,7 @@ class WeightedL1Loss(nn.Module):
self.code_weights = np.array(code_weights, dtype=np.float32) self.code_weights = np.array(code_weights, dtype=np.float32)
self.code_weights = torch.from_numpy(self.code_weights).cuda() self.code_weights = torch.from_numpy(self.code_weights).cuda()
@torch.amp.cuda.custom_fwd(cast_inputs=torch.float16)
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None): def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None):
""" """
Args: Args:
...@@ -383,4 +384,4 @@ class RegLossCenterNet(nn.Module): ...@@ -383,4 +384,4 @@ class RegLossCenterNet(nn.Module):
else: else:
pred = _transpose_and_gather_feat(output, ind) pred = _transpose_and_gather_feat(output, ind)
loss = _reg_loss(pred, target, mask) loss = _reg_loss(pred, target, mask)
return loss return loss
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment