"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "95b374952dc27d8511541d6f5a4e22c9ec11fb24"
Unverified Commit 123cce6f authored by Stas Bekman, committed by GitHub

[modeling_utils] respect original dtype in _get_resized_lm_head (#14181)



* respect dtype in _get_resized_lm_head

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistency
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 88cd82e8
@@ -789,9 +789,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )
 
         # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
-            self.device, dtype=old_embeddings.weight.dtype
-        )
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -862,7 +861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Build new lm head
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
-        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
+        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
+        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
...
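The practical effect: before this fix, resizing the vocabulary of a model held in half precision rebuilt the LM head at PyTorch's default float32, silently mixing dtypes; the embedding hunk only splits construction from the .to() call so the two helpers read the same way. A minimal sketch of the dtype behavior the lm-head hunk fixes, using illustrative standalone layers and sizes rather than anything from the commit:

import torch
import torch.nn as nn

# Stand-in for the existing LM head of a model running in fp16
# (768 and 50257/50260 are illustrative GPT-2-like sizes).
old_lm_head = nn.Linear(768, 50257).to(torch.float16)

# Pre-fix pattern: .to(device) alone leaves the fresh layer at the
# PyTorch default dtype (float32), regardless of the old head's dtype.
resized = nn.Linear(768, 50260).to("cpu")
print(resized.weight.dtype)  # torch.float32

# Post-fix pattern: forwarding the old head's dtype keeps the resized
# head consistent with the rest of the model.
resized = nn.Linear(768, 50260).to("cpu", dtype=old_lm_head.weight.dtype)
print(resized.weight.dtype)  # torch.float16

In user code, this is what makes model.resize_token_embeddings(len(tokenizer)) behave as expected on a model loaded with torch_dtype=torch.float16: the rebuilt head inherits fp16 instead of reverting to fp32.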