"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "a0090aa1841562c7e046b3d31a3988e96a97d4a4"
Unverified Commit 4bf50422 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Fix BART tests on GPU (#4298)

parent e4512aab
...@@ -886,7 +886,7 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -886,7 +886,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
if new_num_tokens <= old_num_tokens: if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens] new_bias = self.final_logits_bias[:, :new_num_tokens]
else: else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens)) extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias) self.register_buffer("final_logits_bias", new_bias)
......
...@@ -690,4 +690,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase): ...@@ -690,4 +690,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
# test that forward pass is just a lookup, there is no ignore padding logic # test that forward pass is just a lookup, there is no ignore padding logic
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device) input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
no_cache_pad_zero = emb1(input_ids) no_cache_pad_zero = emb1(input_ids)
self.assertTrue(torch.allclose(torch.Tensor(self.desired_weights), no_cache_pad_zero[:3, :5], atol=1e-3)) self.assertTrue(
torch.allclose(
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
)
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment