"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d740351f7dfd6176e40efaeca694aca5622a55cd"
Commit e0f867a9 authored by LysandreJik

XLNet bias fix on resize embeddings (cf #1124)

parent d7a4c325
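When `resize_token_embeddings` grows the vocabulary, weight tying copies only `weight` from the input embeddings to the output projection; an output bias (as in XLNet's LM head, cf #1124) keeps its old length, and the next forward pass fails with a shape mismatch. The diff below zero-pads the bias up to the new size. As a standalone illustration of the `torch.nn.functional.pad` call the patch adds (a sketch, not code from the commit):

```python
import torch
import torch.nn.functional as F

bias = torch.tensor([0.1, 0.2, 0.3])  # bias sized for a 3-token vocabulary
# pad takes (left, right) amounts for the last dimension: append 2 zeros on
# the right, as if the vocabulary had grown from 3 to 5 tokens.
padded = F.pad(bias, (0, 2), 'constant', 0)
print(padded)  # tensor([0.1000, 0.2000, 0.3000, 0.0000, 0.0000])
```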
@@ -327,6 +327,14 @@ class PreTrainedModel(nn.Module):
         else:
             first_module.weight = second_module.weight
 
+        if hasattr(first_module, 'bias') and first_module.bias is not None:
+            first_module.bias.data = torch.nn.functional.pad(
+                first_module.bias.data,
+                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
+                'constant',
+                0
+            )
+
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
             Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
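A minimal sketch of the scenario the patch handles, built from plain `torch` modules; the names (`embedding`, `lm_head`) are illustrative, not the ones used in the repository:

```python
import torch
import torch.nn as nn

vocab_size, hidden, new_vocab_size = 8, 4, 10

embedding = nn.Embedding(vocab_size, hidden)        # input embeddings
lm_head = nn.Linear(hidden, vocab_size, bias=True)  # tied output projection

# Simulate resize_token_embeddings: build a larger embedding and copy the
# old rows over, as the resize helper does.
new_embedding = nn.Embedding(new_vocab_size, hidden)
new_embedding.weight.data[:vocab_size] = embedding.weight.data

# Re-tying copies only the weight, leaving the bias at the old size.
lm_head.weight = new_embedding.weight       # weight is now (10, 4)
assert lm_head.bias.shape[0] == vocab_size  # but bias is still (8,)

# The patched logic: zero-pad the bias on the right to match the weight rows.
if hasattr(lm_head, 'bias') and lm_head.bias is not None:
    lm_head.bias.data = torch.nn.functional.pad(
        lm_head.bias.data,
        (0, lm_head.weight.shape[0] - lm_head.bias.shape[0]),
        'constant',
        0,
    )
assert lm_head.bias.shape[0] == new_vocab_size  # bias now (10,)
```

Zero-padding leaves the logits of existing tokens unchanged and gives the newly added tokens a neutral starting bias.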