Commit e7cfc46f authored by thomwolf's avatar thomwolf
Browse files

fix TransfoXLModel loading

parent 3c33499f
......@@ -959,7 +959,12 @@ class TransfoXLPreTrainedModel(nn.Module):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='')
start_prefix = ''
if not hasattr(model, 'transformer') and any(s.startswith('transformer.') for s in state_dict.keys()):
start_prefix = 'transformer.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment