"vscode:/vscode.git/clone" did not exist on "f2c4ce7e339f4a2f8aaacb392496bc1a5743881f"
Unverified Commit b9752161 authored by Pavel Iakubovskii, committed by GitHub

Fix RT-DETR cache for generate_anchors (#31671)

* Fix cache and type conversion

* Add test

* Fixup

* nit

* [run slow] rt_detr

* Fix test

* Fixup

* [run slow] rt_detr

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
parent 534cbf8a
src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -1656,7 +1656,11 @@ class RTDetrModel(RTDetrPreTrainedModel):
             param.requires_grad_(True)
 
     @lru_cache(maxsize=32)
-    def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"):
+    def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
+        # We always generate anchors in float32 to preserve equivalence between
+        # dynamic and static anchor inference
+        dtype = torch.float32
+
         if spatial_shapes is None:
             spatial_shapes = [
                 [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
@@ -1674,7 +1678,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
         # define the valid range for anchor coordinates
         eps = 1e-2
-        anchors = torch.concat(anchors, 1).to(device)
+        anchors = torch.concat(anchors, 1)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
         anchors = torch.where(valid_mask, anchors, torch.inf)
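The two hunks above move dtype handling out of the anchor computation: anchors are always built in float32 and only cast afterwards. The reason is numerical; the inverse sigmoid torch.log(anchors / (1 - anchors)) drifts when evaluated directly in half precision, so a static anchor buffer built in float32 and a dynamic path built in the runtime dtype would disagree. A standalone sketch of that drift, with illustrative values that are not taken from the model:

    import torch

    # anchor centers clamped away from 0 and 1, as generate_anchors does (eps = 1e-2)
    x = torch.rand(64).clamp(1e-2, 1 - 1e-2).to(torch.bfloat16)

    # inverse sigmoid evaluated directly in bfloat16: every op rounds
    direct = torch.log(x / (1 - x))

    # float32 inside, a single cast at the end -- the strategy the patch adopts
    # for both the static and the dynamic anchor path
    stable = torch.log(x.float() / (1 - x.float())).to(torch.bfloat16)

    print((direct - stable).abs().max())  # typically nonzero: the two strategies disagree

The remaining hunks update the call site accordingly: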
@@ -1769,15 +1773,15 @@ class RTDetrModel(RTDetrPreTrainedModel):
 
         # Prepare encoder inputs (by flattening)
         source_flatten = []
-        spatial_shapes = []
+        spatial_shapes_list = []
         for level, source in enumerate(sources):
             batch_size, num_channels, height, width = source.shape
             spatial_shape = (height, width)
-            spatial_shapes.append(spatial_shape)
+            spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             source_flatten.append(source)
         source_flatten = torch.cat(source_flatten, 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
 
         # prepare denoising training
@@ -1805,9 +1809,14 @@ class RTDetrModel(RTDetrPreTrainedModel):
 
         # prepare input for decoder
         if self.training or self.config.anchor_image_size is None:
-            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
+            # Pass spatial_shapes as tuple to make it hashable and make sure
+            # lru_cache is working for generate_anchors()
+            spatial_shapes_tuple = tuple(spatial_shapes_list)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
         else:
-            anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
+            anchors, valid_mask = self.anchors, self.valid_mask
+
+        anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
 
         # use the valid_mask to selectively retain values in the feature map where the mask is `True`
         memory = valid_mask.to(source_flatten.dtype) * source_flatten
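The call-site change is about functools.lru_cache itself: the wrapper hashes its arguments to build the cache key, so caching only works if every argument hashes by value. A plain list is unhashable, and the tensor that spatial_shapes has become by this point in forward hashes by identity, so each forward pass would create a fresh cache entry and never hit. A minimal sketch, with a hypothetical cached_area helper standing in for generate_anchors:

    from functools import lru_cache

    import torch

    @lru_cache(maxsize=32)
    def cached_area(spatial_shapes):
        # the lru_cache wrapper hashes its arguments before this body ever runs
        return sum(h * w for h, w in spatial_shapes)

    shapes = [(32, 32), (16, 16)]

    try:
        cached_area(shapes)                   # lists are unhashable
    except TypeError as err:
        print(err)                            # unhashable type: 'list'

    # tensors hash by identity, so equal-valued tensors never share an entry
    cached_area(torch.tensor(shapes))
    cached_area(torch.tensor(shapes))
    print(cached_area.cache_info().hits)      # 0 -- the cache never helps

    # tuples hash by value, which is why the patch builds spatial_shapes_tuple
    cached_area(tuple(shapes))
    cached_area(tuple(shapes))
    print(cached_area.cache_info().hits)      # 1

Removing device and dtype from the cached signature has the same flavor of benefit: one cache entry per anchor geometry instead of one per geometry/device/dtype combination.

The accompanying test changes: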
tests/models/rt_detr/test_modeling_rt_detr.py

@@ -16,6 +16,7 @@
 import inspect
 import math
+import tempfile
 import unittest
 
 from parameterized import parameterized
@@ -630,6 +631,48 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 _ = model(**self._prepare_for_class(inputs_dict, model_class))
 
+    @parameterized.expand(["float32", "float16", "bfloat16"])
+    @require_torch_gpu
+    @slow
+    def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str):
+        torch_dtype = {
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+        }[torch_dtype_str]
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        h, w = inputs_dict["pixel_values"].shape[-2:]
+
+        # convert inputs to the desired dtype
+        for key, tensor in inputs_dict.items():
+            if tensor.dtype == torch.float32:
+                inputs_dict[key] = tensor.to(torch_dtype)
+
+        for model_class in self.all_model_classes:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model_class(config).save_pretrained(tmpdirname)
+                model_static = model_class.from_pretrained(
+                    tmpdirname, anchor_image_size=[h, w], device_map=torch_device, torch_dtype=torch_dtype
+                ).eval()
+                model_dynamic = model_class.from_pretrained(
+                    tmpdirname, anchor_image_size=None, device_map=torch_device, torch_dtype=torch_dtype
+                ).eval()
+
+            self.assertIsNotNone(model_static.config.anchor_image_size)
+            self.assertIsNone(model_dynamic.config.anchor_image_size)
+
+            with torch.no_grad():
+                outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class))
+                outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class))
+
+            self.assertTrue(
+                torch.allclose(
+                    outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4
+                ),
+                f"Max diff: {(outputs_static.last_hidden_state - outputs_dynamic.last_hidden_state).abs().max()}",
+            )
+
 TOLERANCE = 1e-4
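The new test saves a randomly initialized model and reloads it twice from the same directory, so the static variant (anchor_image_size=[h, w], anchors precomputed at init) and the dynamic variant (anchor_image_size=None, anchors built per forward through the cached path) share identical weights; their hidden states must then agree to 1e-4 in all three dtypes. Being @slow-gated, it only runs when RUN_SLOW is set, which the "[run slow] rt_detr" commit messages exist to trigger on CI. Outside the test suite, the same equivalence can be spot-checked through the public API; a hedged sketch, assuming the PekingU/rtdetr_r50vd checkpoint and its default 640x640 anchor_image_size:

    import torch
    from transformers import RTDetrForObjectDetection

    checkpoint = "PekingU/rtdetr_r50vd"  # assumed checkpoint; any RT-DETR one works

    # static anchors: precomputed at init from config.anchor_image_size
    static = RTDetrForObjectDetection.from_pretrained(checkpoint).eval()
    # dynamic anchors: generated per forward via the (now working) lru_cache path
    dynamic = RTDetrForObjectDetection.from_pretrained(checkpoint, anchor_image_size=None).eval()

    pixel_values = torch.rand(1, 3, 640, 640)
    with torch.no_grad():
        logits_static = static(pixel_values).logits
        logits_dynamic = dynamic(pixel_values).logits

    print((logits_static - logits_dynamic).abs().max())  # ~0 after this fix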