Unverified commit b1ec7454 authored by Pavel Iakubovskii, committed by GitHub

Fix RT-DETR inference with float16 and bfloat16 (#31639)



* [run_slow] rt_detr

* Fix positional embeddings and anchors dtypes

* [run slow] rt_detr

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixup

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3f93fd06
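
For context, a minimal sketch of the half-precision inference path this commit unblocks. The checkpoint name `PekingU/rtdetr_r50vd` and the `RTDetrForObjectDetection` / `post_process_object_detection` API are assumptions based on the usual Transformers object-detection workflow, not taken from this diff; adjust to your setup.

```python
import requests
import torch
from PIL import Image

from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

# COCO sample image commonly used in Transformers examples.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Checkpoint name is an assumption; any RT-DETR checkpoint should work.
checkpoint = "PekingU/rtdetr_r50vd"
processor = RTDetrImageProcessor.from_pretrained(checkpoint)
model = RTDetrForObjectDetection.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")
model.eval()

inputs = processor(images=image, return_tensors="pt").to("cuda")
# Cast inputs to the model dtype; with this fix the internal buffers follow suit.
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3
)
```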
@@ -1359,7 +1359,7 @@ class RTDetrHybridEncoder(nn.Module):
         if self.training or self.eval_size is None:
             pos_embed = self.build_2d_sincos_position_embedding(
                 width, height, self.encoder_hidden_dim, self.positional_encoding_temperature
-            ).to(src_flatten.device)
+            ).to(src_flatten.device, src_flatten.dtype)
         else:
             pos_embed = None
@@ -1801,12 +1801,13 @@ class RTDetrModel(RTDetrPreTrainedModel):
         batch_size = len(source_flatten)
         device = source_flatten.device
+        dtype = source_flatten.dtype

         # prepare input for decoder
         if self.training or self.config.anchor_image_size is None:
-            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
         else:
-            anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device)
+            anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)

         # use the valid_mask to selectively retain values in the feature map where the mask is `True`
         memory = valid_mask.to(source_flatten.dtype) * source_flatten
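
The anchors and the sine-cosine position embedding are built on the fly rather than stored as model parameters, so `model.to(torch.float16)` never touches them; left in float32 they promote the half-precision activations back to float32 and a later float16 layer rejects the input. The snippet below is an illustrative sketch of that failure mode, not code from the model, and assumes a CUDA device (mirroring the `@require_torch_gpu` test added below).

```python
import torch

device = "cuda"

hidden_states = torch.randn(1, 4, 8, dtype=torch.float16, device=device)  # half-precision activations
anchors = torch.rand(1, 4, 8, dtype=torch.float32, device=device)         # buffer left in float32

mixed = anchors * hidden_states
print(mixed.dtype)  # torch.float32 -- type promotion silently undoes the half cast

layer = torch.nn.Linear(8, 8).to(device, torch.float16)
try:
    layer(mixed)  # dtype mismatch between float32 input and float16 weights
except RuntimeError as err:
    print(err)

out = layer(mixed.to(hidden_states.dtype))  # fine once the buffer follows the activation dtype, as in this fix
```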
@@ -18,6 +18,8 @@ import inspect
 import math
 import unittest

+from parameterized import parameterized
+
 from transformers import (
     RTDetrConfig,
     RTDetrImageProcessor,
@@ -25,7 +27,7 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_torch, require_vision, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -606,6 +608,28 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )

+    @parameterized.expand(["float32", "float16", "bfloat16"])
+    @require_torch_gpu
+    @slow
+    def test_inference_with_different_dtypes(self, torch_dtype_str):
+        torch_dtype = {
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+        }[torch_dtype_str]
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device).to(torch_dtype)
+            model.eval()
+            for key, tensor in inputs_dict.items():
+                if tensor.dtype == torch.float32:
+                    inputs_dict[key] = tensor.to(torch_dtype)
+            with torch.no_grad():
+                _ = model(**self._prepare_for_class(inputs_dict, model_class))
+
 TOLERANCE = 1e-4