Unverified Commit 4a66c0d9 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

enable training mask2former and maskformer for transformers trainer (#28277)

* fix get_num_masks output as [int] to int

* fix loss size from torch.Size([1]) to torch.Size([])
parent 6b8ec258
...@@ -787,7 +787,7 @@ class Mask2FormerLoss(nn.Module): ...@@ -787,7 +787,7 @@ class Mask2FormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes. Computes the average number of target masks across the batch, for normalization purposes.
""" """
num_masks = sum([len(classes) for classes in class_labels]) num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
return num_masks_pt return num_masks_pt
......
...@@ -1193,7 +1193,7 @@ class MaskFormerLoss(nn.Module): ...@@ -1193,7 +1193,7 @@ class MaskFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes. Computes the average number of target masks across the batch, for normalization purposes.
""" """
num_masks = sum([len(classes) for classes in class_labels]) num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
return num_masks_pt return num_masks_pt
......
...@@ -190,7 +190,7 @@ class Mask2FormerModelTester: ...@@ -190,7 +190,7 @@ class Mask2FormerModelTester:
comm_check_on_output(result) comm_check_on_output(result)
self.parent.assertTrue(result.loss is not None) self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1])) self.parent.assertEqual(result.loss.shape, torch.Size([]))
@require_torch @require_torch
......
...@@ -189,7 +189,7 @@ class MaskFormerModelTester: ...@@ -189,7 +189,7 @@ class MaskFormerModelTester:
comm_check_on_output(result) comm_check_on_output(result)
self.parent.assertTrue(result.loss is not None) self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1])) self.parent.assertEqual(result.loss.shape, torch.Size([]))
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment