Unverified commit 0eb40855, authored by Tanmay patil and committed by GitHub

Support: Leverage Accelerate for object detection/segmentation models (#28312)

* made changes for object detection models

* added support for segmentation models.

* Made changes for segmentation models

* Changed import statements

* solving conflicts

* removed conflicts

* Resolving commits

* Removed conflicts

* Fix: set return_pixel_mask to False
parent aee11fe4
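Every modeling file in this diff applies the same change: the commented-out torch.distributed reduction in the loss is replaced by Accelerate's PartialState, so the target count is averaged across processes whenever Accelerate's state has been initialized. A minimal sketch of that pattern, not part of the diff (the helper name normalize_num_boxes is illustrative):

import torch
from accelerate import PartialState
from accelerate.utils import reduce


def normalize_num_boxes(num_boxes: int, device: torch.device) -> float:
    # Mirrors the logic added to each *Loss.forward in the hunks below.
    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
    world_size = 1
    # PartialState._shared_state is only populated once Accelerate's state exists
    # (for example after the Trainer builds its Accelerator); otherwise this stays single-process.
    if PartialState._shared_state != {}:
        num_boxes = reduce(num_boxes)  # aggregate the per-process counts
        world_size = PartialState().num_processes
    return torch.clamp(num_boxes / world_size, min=1).item()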
@@ -30,6 +30,7 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    is_timm_available,
    is_vision_available,
@@ -41,6 +42,10 @@ from ...utils.backbone_utils import load_backbone
from .configuration_conditional_detr import ConditionalDetrConfig
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
@@ -2507,11 +2512,12 @@ class ConditionalDetrLoss(nn.Module):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-# (Niels): comment out function below, distributed training to be added
-# if is_dist_avail_and_initialized():
-# torch.distributed.all_reduce(num_boxes)
-# (Niels) in original implementation, num_boxes is divided by get_world_size()
-num_boxes = torch.clamp(num_boxes, min=1).item()
+world_size = 1
+if PartialState._shared_state != {}:
+    num_boxes = reduce(num_boxes)
+    world_size = PartialState().num_processes
+num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......
@@ -43,7 +43,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
-from ...utils import is_ninja_available, logging
+from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
@@ -65,6 +65,10 @@ else:
if is_vision_available():
    from transformers.image_transforms import center_to_corners_format
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
class MultiScaleDeformableAttentionFunction(Function):
    @staticmethod
@@ -2246,11 +2250,11 @@ class DeformableDetrLoss(nn.Module):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-# (Niels): comment out function below, distributed training to be added
-# if is_dist_avail_and_initialized():
-# torch.distributed.all_reduce(num_boxes)
-# (Niels) in original implementation, num_boxes is divided by get_world_size()
-num_boxes = torch.clamp(num_boxes, min=1).item()
+world_size = 1
+if PartialState._shared_state != {}:
+    num_boxes = reduce(num_boxes)
+    world_size = PartialState().num_processes
+num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......
@@ -30,6 +30,7 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    is_timm_available,
    is_vision_available,
@@ -41,6 +42,10 @@ from ...utils.backbone_utils import load_backbone
from .configuration_detr import DetrConfig
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
@@ -2204,11 +2209,11 @@ class DetrLoss(nn.Module):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-# (Niels): comment out function below, distributed training to be added
-# if is_dist_avail_and_initialized():
-# torch.distributed.all_reduce(num_boxes)
-# (Niels) in original implementation, num_boxes is divided by get_world_size()
-num_boxes = torch.clamp(num_boxes, min=1).item()
+world_size = 1
+if PartialState._shared_state != {}:
+    num_boxes = reduce(num_boxes)
+    world_size = PartialState().num_processes
+num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......
@@ -34,7 +34,7 @@ from ...file_utils import (
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig
@@ -42,6 +42,10 @@ from .configuration_mask2former import Mask2FormerConfig
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
logger = logging.get_logger(__name__)
@@ -788,6 +792,12 @@ class Mask2FormerLoss(nn.Module):
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+world_size = 1
+if PartialState._shared_state != {}:
+    num_masks_pt = reduce(num_masks_pt)
+    world_size = PartialState().num_processes
+num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
......
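The PartialState._shared_state != {} guard used throughout this diff checks whether Accelerate's shared state has been created yet, so the cross-process reduction only runs under an Accelerate-managed setup. A small illustration of that guard, assuming accelerate is installed (not part of the diff):

from accelerate import PartialState

print(PartialState._shared_state == {})  # True: no Accelerate state has been created yet in this script
state = PartialState()                   # creating a PartialState populates the shared state
print(PartialState._shared_state == {})  # False: the guard in the losses would now be taken
print(state.num_processes)               # 1 in a plain single-process run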
@@ -31,6 +31,7 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    logging,
    replace_return_docstrings,
@@ -42,6 +43,10 @@ from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
@@ -1194,6 +1199,12 @@ class MaskFormerLoss(nn.Module):
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+world_size = 1
+if PartialState._shared_state != {}:
+    num_masks_pt = reduce(num_masks_pt)
+    world_size = PartialState().num_processes
+num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
......
@@ -31,6 +31,7 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    logging,
    replace_return_docstrings,
@@ -40,6 +41,10 @@ from ...utils.backbone_utils import load_backbone
from .configuration_oneformer import OneFormerConfig
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
logger = logging.get_logger(__name__)
@@ -723,6 +728,12 @@ class OneFormerLoss(nn.Module):
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
+world_size = 1
+if PartialState._shared_state != {}:
+    num_masks_pt = reduce(num_masks_pt)
+    world_size = PartialState().num_processes
+num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
......
@@ -30,6 +30,7 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    is_timm_available,
    is_vision_available,
@@ -50,6 +51,10 @@ if is_timm_available():
if is_vision_available():
    from transformers.image_transforms import center_to_corners_format
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "TableTransformerConfig"
@@ -1751,11 +1756,11 @@ class TableTransformerLoss(nn.Module):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-# (Niels): comment out function below, distributed training to be added
-# if is_dist_avail_and_initialized():
-# torch.distributed.all_reduce(num_boxes)
-# (Niels) in original implementation, num_boxes is divided by get_world_size()
-num_boxes = torch.clamp(num_boxes, min=1).item()
+world_size = 1
+if PartialState._shared_state != {}:
+    num_boxes = reduce(num_boxes)
+    world_size = PartialState().num_processes
+num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......
@@ -1297,7 +1297,7 @@ class YolosImageProcessor(BaseImageProcessor):
encoded_inputs = self.pad(
    images,
    annotations=annotations,
-    return_pixel_mask=True,
+    return_pixel_mask=False,
    data_format=data_format,
    input_data_format=input_data_format,
    update_bboxes=do_convert_annotations,
......
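The YolosImageProcessor hunk above switches the internal pad call to return_pixel_mask=False, so padded batches no longer carry a pixel_mask entry; YolosForObjectDetection's forward does not accept a pixel_mask, so nothing consumes it. A quick check, sketched under the assumption that the hustvl/yolos-small checkpoint is used (not part of the diff):

from PIL import Image
from transformers import YolosImageProcessor

processor = YolosImageProcessor.from_pretrained("hustvl/yolos-small")
inputs = processor(images=Image.new("RGB", (640, 480)), return_tensors="pt")
print(inputs.keys())  # expected after this change: pixel_values only, no pixel_mask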
@@ -33,6 +33,7 @@ from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
    is_scipy_available,
    is_vision_available,
    logging,
@@ -48,6 +49,9 @@ if is_scipy_available():
if is_vision_available():
    from transformers.image_transforms import center_to_corners_format
+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
logger = logging.get_logger(__name__)
@@ -1074,11 +1078,11 @@ class YolosLoss(nn.Module):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-# (Niels): comment out function below, distributed training to be added
-# if is_dist_avail_and_initialized():
-# torch.distributed.all_reduce(num_boxes)
-# (Niels) in original implementation, num_boxes is divided by get_world_size()
-num_boxes = torch.clamp(num_boxes, min=1).item()
+world_size = 1
+if PartialState._shared_state != {}:
+    num_boxes = reduce(num_boxes)
+    world_size = PartialState().num_processes
+num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......