Unverified Commit 633db6be authored by Duo Li's avatar Duo Li Committed by GitHub
Browse files

Update loss_utils.py (#1247)

parent e948c537
......@@ -148,6 +148,7 @@ class WeightedL1Loss(nn.Module):
self.code_weights = np.array(code_weights, dtype=np.float32)
self.code_weights = torch.from_numpy(self.code_weights).cuda()
@torch.amp.cuda.custom_fwd(cast_inputs=torch.float16)
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None):
"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment