"deployment/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d628942b6656176d4d6b3c16405e4f640d62cf29"
Unverified Commit 197e7ce9 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix device issue in a `ConvBertModelTest` test (#21438)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 0df80282
...@@ -440,7 +440,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -440,7 +440,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_for_input_embeds(self): def test_model_for_input_embeds(self):
batch_size = 2 batch_size = 2
seq_length = 10 seq_length = 10
inputs_embeds = torch.rand([batch_size, seq_length, 768]) inputs_embeds = torch.rand([batch_size, seq_length, 768], device=torch_device)
config = self.model_tester.get_config() config = self.model_tester.get_config()
model = ConvBertModel(config=config) model = ConvBertModel(config=config)
model.to(torch_device) model.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment