Unverified Commit b599b192 authored by Arthur, committed by GitHub

[ConvBert] Fix #21523 (#21849)

* fix reshaping
Fixes #21523

* add test

* styling

* last fixes

* Update src/transformers/models/convbert/modeling_convbert.py

* code quality
parent 44e3e3fb
src/transformers/models/convbert/modeling_convbert.py

@@ -316,7 +316,7 @@ class ConvBertSelfAttention(nn.Module):
         if config.hidden_size % self.num_attention_heads != 0:
             raise ValueError("hidden_size should be divisible by num_attention_heads")
 
-        self.attention_head_size = config.hidden_size // config.num_attention_heads
+        self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
         self.all_head_size = self.num_attention_heads * self.attention_head_size
 
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
@@ -413,7 +413,10 @@ class ConvBertSelfAttention(nn.Module):
         conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
         context_layer = torch.cat([context_layer, conv_out], 2)
 
-        new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,)
+        # conv and context
+        new_context_layer_shape = context_layer.size()[:-2] + (
+            self.num_attention_heads * self.attention_head_size * 2,
+        )
         context_layer = context_layer.view(*new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
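To see why the new shape works, here is a minimal standalone sketch of the shape arithmetic (not part of the commit; the sizes are arbitrary illustrative values). With head_ratio applied, the effective head count drops to num_attention_heads // head_ratio and attention_head_size is halved, so concatenating the attention and convolution branches along the head axis restores exactly hidden_size channels:

import torch

# Illustrative values only (not from the commit): 768 hidden units, 12 heads, head_ratio 4.
hidden_size, num_attention_heads, head_ratio = 768, 12, 4
num_attention_heads = num_attention_heads // head_ratio          # 3 heads remain
attention_head_size = (hidden_size // num_attention_heads) // 2  # (768 // 3) // 2 = 128

batch_size, seq_len = 2, 16
context_layer = torch.randn(batch_size, seq_len, num_attention_heads, attention_head_size)
conv_out = torch.randn(batch_size, seq_len, num_attention_heads, attention_head_size)

# Concatenate the two branches along the head axis, then flatten the last two dims.
merged = torch.cat([context_layer, conv_out], 2)
new_shape = merged.size()[:-2] + (num_attention_heads * attention_head_size * 2,)
assert merged.view(*new_shape).shape == (batch_size, seq_len, hidden_size)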
tests/models/convbert/test_modeling_convbert.py

@@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         result = model(inputs_embeds=inputs_embeds)
         self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
 
+    def test_reducing_attention_heads(self):
+        config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
+        config.head_ratio = 4
+        self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
+
 
 @require_torch
 class ConvBertModelIntegrationTest(unittest.TestCase):
...
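As a quick end-to-end check, a usage sketch along the lines of the new regression test (the config values here are hypothetical, chosen small for speed): building a ConvBertModel with head_ratio=4 should now produce hidden states of the configured width:

import torch
from transformers import ConvBertConfig, ConvBertModel

# Hypothetical small config; head_ratio=4 mirrors the regression test added above.
config = ConvBertConfig(
    hidden_size=256, num_attention_heads=8, head_ratio=4,
    num_hidden_layers=2, intermediate_size=512,
)
model = ConvBertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
with torch.no_grad():
    outputs = model(input_ids)
# Before the fix, reshaping failed with head_ratio > 1; now the hidden size is preserved.
assert outputs.last_hidden_state.shape == (1, 12, config.hidden_size)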