Unverified Commit 7c622482 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix resize_token_embeddings (#11572)

parent fe82b1bf
...@@ -682,7 +682,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -682,7 +682,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
# Build new embeddings # Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device) new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
self.device, dtype=old_embeddings.weight.dtype
)
# initialize all new embeddings (in particular added tokens) # initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings) self._init_weights(new_embeddings)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment