Unverified Commit c99f2547 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Fix device of masks in tests (#27887)

fix device of mask in tests
parent fc71e815
...@@ -104,7 +104,7 @@ class LlamaModelTester: ...@@ -104,7 +104,7 @@ class LlamaModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
...@@ -107,7 +107,7 @@ class MistralModelTester: ...@@ -107,7 +107,7 @@ class MistralModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
...@@ -104,7 +104,7 @@ class PersimmonModelTester: ...@@ -104,7 +104,7 @@ class PersimmonModelTester:
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment