Unverified Commit 7c622482 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix resize_token_embeddings (#11572)

parent fe82b1bf
......@@ -682,7 +682,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
self.device, dtype=old_embeddings.weight.dtype
)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment