Commit d97b4176 authored by Lysandre's avatar Lysandre
Browse files

Correct device assignment

parent 9a3f9108
...@@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): ...@@ -894,6 +894,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
] ]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None: if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
...@@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): ...@@ -1008,6 +1009,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
] ]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None: if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment