Unverified Commit 2b2a2f8d authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Bart] Fix: put dummy_inputs on correct device (#3398)

* Dummy inputs to model.device

* Move self.device to ModuleUtilsMixin
parent 1a5aefc9
...@@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel): ...@@ -129,8 +129,8 @@ class PretrainedBartModel(PreTrainedModel):
@property @property
def dummy_inputs(self): def dummy_inputs(self):
pad_token = self.config.pad_token_id pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]]) input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,) decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
dummy_inputs = { dummy_inputs = {
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": input_ids.ne(pad_token), "attention_mask": input_ids.ne(pad_token),
......
...@@ -108,6 +108,10 @@ class ModuleUtilsMixin: ...@@ -108,6 +108,10 @@ class ModuleUtilsMixin:
module.mem_rss_post_forward = 0 module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0 module.mem_rss_pre_forward = 0
@property
def device(self):
return next(self.parameters()).device
class PreTrainedModel(nn.Module, ModuleUtilsMixin): class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models. r""" Base class for all models.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment