Unverified Commit 123cce6f authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[modeling_utils] respect original dtype in _get_resized_lm_head (#14181)



* respect dtype in _get_resized_lm_head

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistency
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 88cd82e8
......@@ -789,9 +789,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
self.device, dtype=old_embeddings.weight.dtype
)
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
......@@ -862,7 +861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Build new lm head
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = old_lm_head.bias is not None
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment