"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5d2d51a0fbf2b026bc154754ffece9aba336b2e8"
Unverified Commit 77db257e authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Fix the issue of using only inputs_embeds in convbert model (#21398)

* Fix the input embeds issue with tests

* Fix black and isort issue

* Clean up tests

* Add slow tag to the test introduced

* Incorporate PR feedbacks
parent 65b5035a
...@@ -818,12 +818,12 @@ class ConvBertModel(ConvBertPreTrainedModel): ...@@ -818,12 +818,12 @@ class ConvBertModel(ConvBertPreTrainedModel):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None: elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None: if attention_mask is None:
......
...@@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device) loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
def test_model_for_input_embeds(self):
batch_size = 2
seq_length = 10
inputs_embeds = torch.rand([batch_size, seq_length, 768])
config = self.model_tester.get_config()
model = ConvBertModel(config=config)
model.to(torch_device)
model.eval()
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
@require_torch @require_torch
class ConvBertModelIntegrationTest(unittest.TestCase): class ConvBertModelIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment