Commit f8fb4335 authored by thomwolf's avatar thomwolf
Browse files

clean up a little bit PT <=> TF conversion

parent bebaa140
...@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file ...@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_inputs = tf.constant(inputs_list) tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None, pt_model = pt_model_class(config)
config=config, pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
state_dict=torch.load(pytorch_checkpoint_path, strict-False)
map_location='cpu')) pt_model.eval()
pt_inputs = torch.tensor(inputs_list) pt_inputs = torch.tensor(inputs_list)
with torch.no_grad(): with torch.no_grad():
pto = pt_model(pt_inputs) pto = pt_model(pt_inputs)
......
...@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module): ...@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
""" """
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path: if pretrained_model_name_or_path is not None and (
"albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " + logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
"https://github.com/google-research/google-research/issues/119 for more information.") "https://github.com/google-research/google-research/issues/119 for more information.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment