Unverified commit bb4f816a authored by NielsRogge, committed by GitHub

Patch YOLOS and others (#29353)

Fix the object detection and segmentation losses failing when Accelerate is not installed: the PartialState check is now reachable only behind an is_accelerate_available() guard.
parent 44fe1a1c
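
Every hunk below applies the same one-line guard (plus a small variable rename in the mask models): the PartialState branch only runs when Accelerate is importable. A minimal, self-contained sketch of the patched pattern; normalize_num_boxes is a hypothetical name used for illustration, while in transformers the equivalent logic lives inline in each loss, with is_accelerate_available, PartialState, and reduce coming from transformers.utils and accelerate:

    import torch

    def is_accelerate_available() -> bool:
        # Simplified stand-in for transformers.utils.is_accelerate_available:
        # just probe whether the package can be imported.
        try:
            import accelerate  # noqa: F401
            return True
        except ImportError:
            return False

    def normalize_num_boxes(num_boxes: int, device: torch.device) -> float:
        # Hypothetical helper; the body mirrors the patched loss code.
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        world_size = 1
        if is_accelerate_available():  # the guard this commit adds
            from accelerate import PartialState
            from accelerate.utils import reduce

            if PartialState._shared_state != {}:  # distributed state was initialized
                num_boxes = reduce(num_boxes)  # aggregate across processes
                world_size = PartialState().num_processes
        # Average per process; the clamp keeps the normalizer at least 1.
        return torch.clamp(num_boxes / world_size, min=1).item()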
@@ -2514,6 +2514,7 @@ class ConditionalDetrLoss(nn.Module):
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

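For reference, PartialState uses a shared-state (borg) pattern: its class-level _shared_state dict stays empty until a PartialState or Accelerator is first constructed, so the inner check reads as "the distributed state has been initialized". A quick illustration, assuming accelerate is installed:

    from accelerate import PartialState

    assert PartialState._shared_state == {}  # nothing initialized yet

    state = PartialState()  # initializes the (possibly single-process) state
    assert PartialState._shared_state != {}
    print(state.num_processes)  # 1 when not launched via a distributed launcher
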
@@ -2282,6 +2282,7 @@ class DeformableDetrLoss(nn.Module):
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

@@ -2345,6 +2345,7 @@ class DetaLoss(nn.Module):
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         # Check that we have initialized the distributed state
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

@@ -2210,6 +2210,7 @@ class DetrLoss(nn.Module):
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

@@ -791,14 +791,15 @@ class Mask2FormerLoss(nn.Module):
         Computes the average number of target masks across the batch, for normalization purposes.
         """
         num_masks = sum([len(classes) for classes in class_labels])
-        num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_masks_pt = reduce(num_masks_pt)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_masks = reduce(num_masks)
+                world_size = PartialState().num_processes

-        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
-        return num_masks_pt
+        num_masks = torch.clamp(num_masks / world_size, min=1)
+        return num_masks


 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention

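The three mask-based models additionally fold the num_masks_pt tensor back into num_masks; behavior is unchanged. Note what the single-process fallback does: with the guarded branch skipped, world_size stays 1 and the clamp alone protects against empty batches. A tiny check of that edge case:

    import torch

    num_masks = torch.as_tensor(0, dtype=torch.float)  # batch with no target masks
    world_size = 1  # guarded branch skipped (no Accelerate / no distributed state)
    num_masks = torch.clamp(num_masks / world_size, min=1)
    assert num_masks.item() == 1.0  # avoids dividing the loss by zero downstream
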
@@ -1198,14 +1198,15 @@ class MaskFormerLoss(nn.Module):
         Computes the average number of target masks across the batch, for normalization purposes.
         """
         num_masks = sum([len(classes) for classes in class_labels])
-        num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_masks_pt = reduce(num_masks_pt)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_masks = reduce(num_masks)
+                world_size = PartialState().num_processes

-        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
-        return num_masks_pt
+        num_masks = torch.clamp(num_masks / world_size, min=1)
+        return num_masks


 class MaskFormerFPNConvLayer(nn.Module):

@@ -727,14 +727,15 @@ class OneFormerLoss(nn.Module):
         Computes the average number of target masks across the batch, for normalization purposes.
         """
         num_masks = sum([len(classes) for classes in class_labels])
-        num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
+        num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_masks_pt = reduce(num_masks_pt)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_masks = reduce(num_masks)
+                world_size = PartialState().num_processes

-        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
-        return num_masks_pt
+        num_masks = torch.clamp(num_masks / world_size, min=1)
+        return num_masks


 @dataclass

@@ -1757,6 +1757,7 @@ class TableTransformerLoss(nn.Module):
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

@@ -1079,6 +1079,7 @@ class YolosLoss(nn.Module):
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
         world_size = 1
-        if PartialState._shared_state != {}:
-            num_boxes = reduce(num_boxes)
-            world_size = PartialState().num_processes
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_boxes = reduce(num_boxes)
+                world_size = PartialState().num_processes

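Why the guard matters: in these modeling files, PartialState and reduce are only imported when Accelerate is available, so in a plain PyTorch environment the old unguarded check failed as soon as labels were passed to one of these models. A sketch of the failure mode and the fix, assuming accelerate is absent:

    # Old behavior (sketch): PartialState was never imported, so the line
    #     if PartialState._shared_state != {}:
    # raised a NameError in an environment without accelerate.

    # New behavior: probe the import first, and fall back to a single process.
    try:
        from accelerate import PartialState
        accelerate_available = True
    except ImportError:
        accelerate_available = False

    world_size = 1
    if accelerate_available:
        if PartialState._shared_state != {}:
            world_size = PartialState().num_processes
    # world_size stays 1 and the loss computation proceeds single-process.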