Unverified commit bb4f816a authored by NielsRogge, committed by GitHub
Browse files

Patch YOLOS and others (#29353)

Fix issue
parent 44fe1a1c
......@@ -2514,6 +2514,7 @@ class ConditionalDetrLoss(nn.Module):
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
......@@ -2282,6 +2282,7 @@ class DeformableDetrLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
......@@ -2345,6 +2345,7 @@ class DetaLoss(nn.Module):
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# Check that we have initialized the distributed state
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
......@@ -2210,6 +2210,7 @@ class DetrLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
......@@ -791,14 +791,15 @@ class Mask2FormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
......
......@@ -1198,14 +1198,15 @@ class MaskFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
class MaskFormerFPNConvLayer(nn.Module):
......
......@@ -727,14 +727,15 @@ class OneFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
@dataclass
......
......@@ -1757,6 +1757,7 @@ class TableTransformerLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
......@@ -1079,6 +1079,7 @@ class YolosLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment