Unverified Commit 080a9711 authored by Pedro Lira, committed by GitHub

Add mask2former fp16 support (#25093)

* Add mask2former fp16 support

* Clear consistency/quality issues

* Fix consistency/quality (2)

* Add integration test for mask2former (fp16 case)

* Fix code quality

* Add integration test for maskformer (fp16 case)

* Add integration test for oneformer (fp16 case)

* Remove slow decorator from fp16 tests

* Fix lint

* Remove usage of full inference and value checks for fp16

* Temporarily comment slow for {mask, mask2, one}former

* Add fp16 support to oneformer

* Revert "Temporarily comment slow for {mask, mask2, one}former"

This reverts commit e5371edabd301cf56079def0421a0a87df307cb0.

* Remove dtype conversion noop
parent 5ee9693a
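
For context, the diff below removes hard-coded fp32 casts so the input dtype propagates through the sine position embeddings, reference points, valid ratios, and FPN projections. A minimal sketch of the half-precision inference path this enables, mirroring the new integration tests; the checkpoint name and image URL are illustrative examples, not taken from this commit:

import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# Illustrative checkpoint; the integration tests use `self.model_checkpoints` instead.
checkpoint = "facebook/mask2former-swin-small-coco-instance"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = (
    Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)
    .to("cuda", dtype=torch.float16)
    .eval()
)

# Illustrative test image (the COCO sample commonly used in the docs).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Casting the inputs to fp16 now works end to end: the position embeddings and
# pixel decoder follow the input dtype instead of forcing float32.
inputs = image_processor(image, return_tensors="pt").to("cuda", dtype=torch.float16)

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.class_queries_logits.dtype)  # expected: torch.float16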
@@ -864,15 +864,15 @@ class Mask2FormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_t
@@ -1104,8 +1104,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
                 indexing="ij",
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
@@ -1267,14 +1267,14 @@ class Mask2FormerPixelDecoder(nn.Module):
         self.lateral_convolutions = lateral_convs[::-1]
         self.output_convolutions = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""
         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio
@@ -1295,8 +1295,8 @@ class Mask2FormerPixelDecoder(nn.Module):
         input_embeds = []
         position_embeddings = []
         for level, x in enumerate(features[::-1][: self.num_feature_levels]):
-            input_embeds.append(self.input_projections[level](x.float()))
-            position_embeddings.append(self.position_embedding(x.float()))
+            input_embeds.append(self.input_projections[level](x))
+            position_embeddings.append(self.position_embedding(x))
         masks = [
             torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds
@@ -1313,7 +1313,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1)
+        valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1)

         # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
         if encoder_outputs is None:
@@ -1351,7 +1351,7 @@ class Mask2FormerPixelDecoder(nn.Module):
         for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]):
             lateral_conv = self.lateral_convolutions[idx]
             output_conv = self.output_convolutions[idx]
-            current_fpn = lateral_conv(feature.float())
+            current_fpn = lateral_conv(feature)

             # Following FPN implementation, we use nearest upsampling here
             out = current_fpn + nn.functional.interpolate(
......
@@ -1290,15 +1290,15 @@ class MaskFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_t
......
@@ -1179,8 +1179,8 @@ class OneFormerPixelDecoderEncoderOnly(nn.Module):
         reference_points_list = []
         for lvl, (height, width) in enumerate(spatial_shapes):
             ref_y, ref_x = torch.meshgrid(
-                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
-                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
             )
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
@@ -1352,14 +1352,14 @@ class OneFormerPixelDecoder(nn.Module):
         self.lateral_convs = lateral_convs[::-1]
         self.output_convs = output_convs[::-1]

-    def get_valid_ratio(self, mask):
+    def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""
         _, height, width = mask.shape
         valid_height = torch.sum(~mask[:, :, 0], 1)
         valid_width = torch.sum(~mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
         return valid_ratio
@@ -1380,9 +1380,8 @@ class OneFormerPixelDecoder(nn.Module):
         sources = []
         position_embeddings_list = []
         for level, source in enumerate(features[::-1][: self.num_feature_levels]):
-            feats = source.float()
-            sources.append(self.input_projections[level](feats))
-            position_embeddings_list.append(self.position_embedding(feats))
+            sources.append(self.input_projections[level](source))
+            position_embeddings_list.append(self.position_embedding(source))

         masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources]
@@ -1407,8 +1406,7 @@ class OneFormerPixelDecoder(nn.Module):
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
         spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
-        valid_ratios = valid_ratios.float()
+        valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)

         # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
         # Also provide spatial_shapes, level_start_index and valid_ratios
@@ -1445,7 +1443,6 @@ class OneFormerPixelDecoder(nn.Module):
         # append `out` with extra FPN levels
         # Reverse feature maps into top-down order (from low to high resolution)
         for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]):
-            feats = feats.float()
             lateral_conv = self.lateral_convs[idx]
             output_conv = self.output_convs[idx]
             cur_fpn = lateral_conv(feats)
@@ -2396,15 +2393,15 @@ class OneFormerSinePositionEmbedding(nn.Module):
     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         if mask is None:
             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
-        not_mask = ~mask
-        y_embed = not_mask.cumsum(1, dtype=torch.float32)
-        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        not_mask = (~mask).to(x.dtype)
+        y_embed = not_mask.cumsum(1)
+        x_embed = not_mask.cumsum(2)
         if self.normalize:
             eps = 1e-6
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_t
@@ -2744,7 +2741,7 @@ class OneFormerTaskModel(nn.Module):
         )

     def forward(self, inputs: Tensor) -> Tensor:
-        task_tokens = self.task_mlp(inputs.float())
+        task_tokens = self.task_mlp(inputs)
         return task_tokens
@@ -2980,7 +2977,7 @@ class OneFormerModel(OneFormerPreTrainedModel):
         multi_scale_features = pixel_level_module_output.decoder_features
         mask_features = pixel_level_module_output.decoder_last_feature

-        task_token = self.task_encoder(task_inputs)
+        task_token = self.task_encoder(task_inputs.to(self.dtype))

         if self.is_training:
             text_queries = self.text_mapper(text_inputs)
......
@@ -21,7 +21,14 @@ import numpy as np
 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -420,6 +427,20 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         image_processor = self.default_image_processor
......
@@ -22,7 +22,14 @@ import numpy as np
 from tests.test_modeling_common import floats_tensor
 from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -509,6 +516,20 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         model = (
             MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
......
@@ -22,7 +22,14 @@ import numpy as np
 from tests.test_modeling_common import floats_tensor
 from transformers import OneFormerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester
@@ -533,6 +540,20 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    @require_torch_gpu
+    def test_inference_fp16(self):
+        model = (
+            OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
+            .to(torch_device, dtype=torch.float16)
+            .eval()
+        )
+        processor = self.default_processor
+        image = prepare_img()
+        inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device, dtype=torch.float16)
+        with torch.no_grad():
+            _ = model(**inputs)
+
     def test_with_segmentation_maps_and_loss(self):
         dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
         processor = self.default_processor
......