"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "08546656e8340877408165a29825fdc1d98a71b2"
Commit 162ba383 authored by thomwolf

fix model loading

parent 6dacc79d
@@ -308,7 +308,8 @@ def main():
             input_ids, input_mask, segment_ids, label_ids = batch

             # define a new function to compute loss values for both output_modes
-            logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
+            outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
+            loss = outputs[0]

             if output_mode == "classification":
                 loss_fct = CrossEntropyLoss()
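For context on the hunk above: in pytorch-transformers, passing labels into a model's forward call prepends the loss to the returned tuple, which is what the new loss = outputs[0] line relies on. A minimal sketch of that convention (editor's illustration, not part of the commit; assumes the 'bert-base-uncased' weights can be downloaded):

    # Sketch only: the (loss, logits, ...) tuple layout used by the training loop above.
    import torch
    from pytorch_transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("a quick smoke test")])
    with torch.no_grad():
        outputs = model(input_ids, labels=torch.tensor([1]))
    loss, logits = outputs[0], outputs[1]  # loss comes first because labels were given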
@@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module):
         """
         state_dict = kwargs.pop('state_dict', None)
         cache_dir = kwargs.pop('cache_dir', None)
-        from_tf = kwargs.pop('from_tf', None)
+        from_tf = kwargs.pop('from_tf', False)
+        output_loading_info = kwargs.pop('output_loading_info', False)

         # Load config
         config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
@@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module):
             # Directly load from a TensorFlow checkpoint
             return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

+        # Convert old format to new format if needed from a PyTorch state_dict
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            if 'gamma' in key:
+                new_key = key.replace('gamma', 'weight')
+            if 'beta' in key:
+                new_key = key.replace('beta', 'bias')
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
         # Load from a PyTorch state_dict
         missing_keys = []
         unexpected_keys = []
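A small self-contained demonstration of what the renaming loop above does (editor's sketch with invented keys; real checkpoints contain hundreds of entries):

    # Migrates old TF-style 'gamma'/'beta' parameter names to PyTorch's 'weight'/'bias'.
    state_dict = {
        'bert.embeddings.LayerNorm.gamma': [1.0],
        'bert.embeddings.LayerNorm.beta': [0.0],
        'bert.pooler.dense.weight': [[0.0]],
    }
    old_keys, new_keys = [], []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    print(sorted(state_dict))
    # ['bert.embeddings.LayerNorm.bias', 'bert.embeddings.LayerNorm.weight', 'bert.pooler.dense.weight']

Collecting the renames first and applying them in a second pass avoids mutating the dict while iterating over it.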
@@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module):
         if hasattr(model, 'tie_weights'):
             model.tie_weights()  # make sure word embedding weights are still tied

+        if output_loading_info:
+            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            return model, loading_info
+
         return model
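Caller-side view of the flag added above (editor's sketch; 'bert-base-uncased' is just the usual shortcut name):

    # Inspect the loading diagnostics exposed by output_loading_info=True.
    from pytorch_transformers import BertModel

    model, loading_info = BertModel.from_pretrained('bert-base-uncased', output_loading_info=True)
    print(loading_info['missing_keys'])     # params the model defines but the file lacks
    print(loading_info['unexpected_keys'])  # params the file has but the model does not use
    print(loading_info['error_msgs'])       # e.g. shape-mismatch messages

An empty list for all three means the checkpoint matched the architecture exactly, which is what the new test below asserts.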
@@ -17,21 +17,24 @@ from __future__ import division
 from __future__ import print_function

 import unittest
 import logging

 from pytorch_transformers import PretrainedConfig, PreTrainedModel
 from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP


 class ModelUtilsTest(unittest.TestCase):
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             config = BertConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, PretrainedConfig)
-            model = BertModel.from_pretrained(model_name)
+            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
             self.assertIsNotNone(model)
             self.assertIsInstance(model, PreTrainedModel)
+            for value in loading_info.values():
+                self.assertEqual(len(value), 0)
+
             config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
             model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
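Assuming the test module lives under pytorch_transformers/tests/ as the imports suggest (the diff does not show the file path), the new assertion can be run with:

    python -m pytest -sv ./pytorch_transformers/tests/modeling_utils_test.py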