"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "87b04b0fc3998c28f54ae26aea08df088c7df83e"
Unverified commit 659829b6, authored by amyeroberts, committed by GitHub

MaskFormer - enable return_dict in order to compile (#25052)

* Enable return_dict in order to compile

* Update tests
parent b914ec98
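
For context on why the tuple path matters: `torch.compile` traces most cleanly when a model returns plain tuples rather than `ModelOutput` dataclasses, which is what threading `return_dict` through the internal modules makes possible. A minimal usage sketch, assuming a PyTorch 2.x environment; the checkpoint name, input shape, and tuple indices are illustrative assumptions, not part of this commit:

```python
# Sketch: compiling MaskFormer with tuple outputs after this change.
# Assumptions: PyTorch 2.x, the facebook/maskformer-swin-base-ade checkpoint,
# and a dummy 384x384 batch; none of these are prescribed by the commit.
import torch
from transformers import MaskFormerForInstanceSegmentation

model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
model.eval()

compiled_model = torch.compile(model)

pixel_values = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    # return_dict=False now propagates through every submodule, so the whole
    # forward pass returns plain tuples that compile can trace.
    outputs = compiled_model(pixel_values, return_dict=False)

# With no labels passed, loss is absent, so index 0 is class_queries_logits
# and index 1 is masks_queries_logits (see the segmentation head hunk below).
class_queries_logits, masks_queries_logits = outputs[0], outputs[1]
```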
src/transformers/models/maskformer/modeling_maskformer.py

@@ -1254,11 +1254,16 @@ class MaskFormerPixelDecoder(nn.Module):
         self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
         self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
 
-    def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput:
+    def forward(
+        self, features: List[Tensor], output_hidden_states: bool = False, return_dict: bool = True
+    ) -> MaskFormerPixelDecoderOutput:
         fpn_features = self.fpn(features)
         # we use the last feature map
         last_feature_projected = self.mask_projection(fpn_features[-1])
+
+        if not return_dict:
+            return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,)
+
         return MaskFormerPixelDecoderOutput(
             last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
         )
@@ -1387,9 +1392,20 @@ class MaskFormerPixelLevelModule(nn.Module):
             lateral_widths=feature_channels[:-1],
         )
 
-    def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput:
+    def forward(
+        self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True
+    ) -> MaskFormerPixelLevelModuleOutput:
         features = self.encoder(pixel_values).feature_maps
-        decoder_output = self.decoder(features, output_hidden_states)
+        decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict)
+
+        if not return_dict:
+            last_hidden_state = decoder_output[0]
+            outputs = (features[-1], last_hidden_state)
+            if output_hidden_states:
+                hidden_states = decoder_output[1]
+                outputs = outputs + (tuple(features),) + (hidden_states,)
+            return outputs
+
         return MaskFormerPixelLevelModuleOutput(
             # the last feature is actually the output from the last layer
             encoder_last_hidden_state=features[-1],
@@ -1414,7 +1430,11 @@ class MaskFormerTransformerModule(nn.Module):
         self.decoder = DetrDecoder(config=config.decoder_config)
 
     def forward(
-        self, image_features: Tensor, output_hidden_states: bool = False, output_attentions: bool = False
+        self,
+        image_features: Tensor,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
+        return_dict: Optional[bool] = None,
     ) -> DetrDecoderOutput:
         if self.input_projection is not None:
             image_features = self.input_projection(image_features)
@@ -1438,7 +1458,7 @@ class MaskFormerTransformerModule(nn.Module):
             query_position_embeddings=queries_embeddings,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=None,
+            return_dict=return_dict,
         )
 
         return decoder_output
@@ -1593,9 +1613,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
         if pixel_mask is None:
             pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
 
-        pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states)
-        image_features = pixel_level_module_output.encoder_last_hidden_state
-        pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state
+        pixel_level_module_output = self.pixel_level_module(
+            pixel_values, output_hidden_states, return_dict=return_dict
+        )
+        image_features = pixel_level_module_output[0]
+        pixel_embeddings = pixel_level_module_output[1]
 
         transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions)
         queries = transformer_module_output.last_hidden_state
@@ -1606,9 +1628,9 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
         hidden_states = None
         if output_hidden_states:
-            encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
-            pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
-            transformer_decoder_hidden_states = transformer_module_output.hidden_states
+            encoder_hidden_states = pixel_level_module_output[2]
+            pixel_decoder_hidden_states = pixel_level_module_output[3]
+            transformer_decoder_hidden_states = transformer_module_output[1]
             hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
 
         output = MaskFormerModelOutput(
@@ -1803,13 +1825,25 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs: MaskFormerModelOutput = self.model(
+        raw_outputs = self.model(
             pixel_values,
             pixel_mask,
             output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
-            return_dict=True,
+            return_dict=return_dict,
             output_attentions=output_attentions,
         )
+        # We need to have raw_outputs optionally be returned as a dict to use torch.compile. For backwards
+        # compatibility we convert to a dataclass for the rest of the model logic
+        outputs = MaskFormerModelOutput(
+            encoder_last_hidden_state=raw_outputs[0],
+            pixel_decoder_last_hidden_state=raw_outputs[1],
+            transformer_decoder_last_hidden_state=raw_outputs[2],
+            encoder_hidden_states=raw_outputs[3] if output_hidden_states else None,
+            pixel_decoder_hidden_states=raw_outputs[4] if output_hidden_states else None,
+            transformer_decoder_hidden_states=raw_outputs[5] if output_hidden_states else None,
+            hidden_states=raw_outputs[6] if output_hidden_states else None,
+            attentions=raw_outputs[-1] if output_attentions else None,
+        )
 
         loss, loss_dict, auxiliary_logits = None, None, None
@@ -1827,16 +1861,18 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         if not output_auxiliary_logits:
             auxiliary_logits = None
 
-        output = MaskFormerForInstanceSegmentationOutput(
+        if not return_dict:
+            output = tuple(
+                v
+                for v in (loss, class_queries_logits, masks_queries_logits, auxiliary_logits, *outputs.values())
+                if v is not None
+            )
+            return output
+
+        return MaskFormerForInstanceSegmentationOutput(
             loss=loss,
             **outputs,
             class_queries_logits=class_queries_logits,
             masks_queries_logits=masks_queries_logits,
             auxiliary_logits=auxiliary_logits,
         )
-
-        if not return_dict:
-            output = tuple(v for v in output.values())
-            if loss is not None:
-                output = ((loss)) + output
-        return output
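
For consumers of the tuple path, here is a hedged sketch of the index convention the conversion above relies on for `MaskFormerModel`; the randomly initialised config and dummy shapes are assumptions for illustration only:

```python
# Sketch: unpacking MaskFormerModel tuple outputs, mirroring the raw_outputs
# indices used in the hunk above. Positions 3-6 are present only when
# output_hidden_states=True; attentions, when requested, sit at the end.
# The default config and input shape are illustrative assumptions.
import torch
from transformers import MaskFormerConfig, MaskFormerModel

model = MaskFormerModel(MaskFormerConfig())  # randomly initialised for the sketch
model.eval()

pixel_values = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    raw = model(pixel_values, output_hidden_states=True, return_dict=False)

encoder_last_hidden_state = raw[0]
pixel_decoder_last_hidden_state = raw[1]
transformer_decoder_last_hidden_state = raw[2]
encoder_hidden_states = raw[3]
pixel_decoder_hidden_states = raw[4]
transformer_decoder_hidden_states = raw[5]
hidden_states = raw[6]  # the three hidden-state tuples concatenated
```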
tests/models/maskformer/test_modeling_maskformer.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Testing suite for the PyTorch MaskFormer model. """
 
+import copy
 import inspect
 import unittest
@@ -54,6 +55,8 @@ class MaskFormerModelTester:
         max_size=32 * 6,
         num_labels=4,
         mask_feature_size=32,
+        num_hidden_layers=2,
+        num_attention_heads=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -65,6 +68,9 @@ class MaskFormerModelTester:
         self.max_size = max_size
         self.num_labels = num_labels
         self.mask_feature_size = mask_feature_size
+        # This is passed to the decoder config. We add it to the model tester here for testing
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
), ),
decoder_config=DetrConfig( decoder_config=DetrConfig(
decoder_ffn_dim=64, decoder_ffn_dim=64,
decoder_layers=2, decoder_layers=self.num_hidden_layers,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=64, encoder_ffn_dim=64,
encoder_layers=2, encoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
num_queries=self.num_queries, num_queries=self.num_queries,
decoder_attention_heads=2,
d_model=self.mask_feature_size, d_model=self.mask_feature_size,
), ),
mask_feature_size=self.mask_feature_size, mask_feature_size=self.mask_feature_size,
@@ -196,6 +203,27 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self.model_tester = MaskFormerModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MaskFormerConfig, has_text_modality=False)
 
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
+        if return_labels:
+            if model_class in [MaskFormerForInstanceSegmentation]:
+                inputs_dict["mask_labels"] = torch.zeros(
+                    (
+                        self.model_tester.batch_size,
+                        self.model_tester.num_labels,
+                        self.model_tester.min_size,
+                        self.model_tester.max_size,
+                    ),
+                    dtype=torch.float32,
+                    device=torch_device,
+                )
+                inputs_dict["class_labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.num_labels), dtype=torch.long, device=torch_device
+                )
+
+        return inputs_dict
+
     def test_config(self):
         self.config_tester.run_common_tests()
@@ -265,26 +293,47 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self.model_tester.create_and_check_maskformer_model(config, **inputs, output_hidden_states=True)
 
     def test_attention_outputs(self):
-        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-        config.return_dict = True
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
-            model = model_class(config).to(torch_device)
-            outputs = model(**inputs, output_attentions=True)
-            self.assertTrue(outputs.attentions is not None)
-
-    def test_training(self):
-        if not self.model_tester.is_training:
-            return
-        # only MaskFormerForInstanceSegmentation has the loss
-        model_class = self.all_model_classes[1]
-        config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs()
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
-        model = model_class(config)
-        model.to(torch_device)
-        model.train()
+            # Check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
-        loss = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels).loss
-        loss.backward()
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # encoder_hidden_states, pixel_decoder_hidden_states, transformer_decoder_hidden_states, hidden_states
+            added_hidden_states = 4
+            self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+            self_attentions = outputs.attentions
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
 
     def test_retain_grad_hidden_states_attentions(self):
         # only MaskFormerForInstanceSegmentation has the loss
......