"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cce3089b65df1b60b560c2d86277c6298d266772"
Unverified Commit ef998529 authored by eukaryote's avatar eukaryote Committed by GitHub
Browse files

from_pretrained: convert DialoGPT format

DialoGPT checkpoints have "lm_head.decoder.weight" instead of "lm_head.weight". 

(see: https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/f6vmwuy?utm_source=share&utm_medium=web2x)
parent 7a9aae10
...@@ -417,6 +417,8 @@ class PreTrainedModel(nn.Module): ...@@ -417,6 +417,8 @@ class PreTrainedModel(nn.Module):
new_key = key.replace('gamma', 'weight') new_key = key.replace('gamma', 'weight')
if 'beta' in key: if 'beta' in key:
new_key = key.replace('beta', 'bias') new_key = key.replace('beta', 'bias')
if key == 'lm_head.decoder.weight':
new_key = 'lm_head.weight'
if new_key: if new_key:
old_keys.append(key) old_keys.append(key)
new_keys.append(new_key) new_keys.append(new_key)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment