chenpangpang / transformers · Commit 54abc67a

Unverified commit 54abc67a, authored Dec 22, 2019 by Thomas Wolf, committed via GitHub on Dec 22, 2019

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices

Parents: 645713e2, c11b3e29
Changes: 205
Showing 20 changed files with 1153 additions and 885 deletions (+1153, -885)
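The diffs below apply a small set of mechanical conventions across all 20 files: the three per-line `__future__` imports are collapsed into one, import blocks are sorted and grouped, single-quoted strings become double-quoted, long call sites are re-wrapped one argument per line with trailing commas, and `##` comments become `#` comments. As a hedged illustration of the before/after pattern (this snippet is illustrative, not taken verbatim from the diff):

# Before (style of the old files): one __future__ import per line,
# unsorted imports, single-quoted strings:
#
#     from __future__ import absolute_import
#     from __future__ import division
#     from __future__ import print_function
#     import torch
#     import argparse
#     path = folder + '/' + 'weights.bin'

# After: one combined __future__ import, sorted and grouped imports,
# double-quoted strings.
from __future__ import absolute_import, division, print_function

import argparse

import torch

folder = "."  # placeholder so the snippet runs standalone
path = folder + "/" + "weights.bin"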
transformers/convert_bert_original_tf_checkpoint_to_pytorch.py            +21   -24
transformers/convert_bert_pytorch_checkpoint_to_original_tf.py            +27   -45
transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py            +20   -26
transformers/convert_openai_original_tf_checkpoint_to_pytorch.py          +26   -26
transformers/convert_pytorch_checkpoint_to_tf2.py                         +347  -132
transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py    +52   -65
transformers/convert_t5_original_tf_checkpoint_to_pytorch.py              +21   -24
transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py      +51   -36
transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py        +21   -23
transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py           +48   -37
transformers/data/__init__.py                                             +23   -4
transformers/data/metrics/__init__.py                                     +4    -6
transformers/data/metrics/squad_metrics.py                                +83   -92
transformers/data/processors/__init__.py                                  +8    -4
transformers/data/processors/glue.py                                      +130  -131
transformers/data/processors/squad.py                                     +76   -53
transformers/data/processors/utils.py                                     +67   -50
transformers/data/processors/xnli.py                                      +8    -7
transformers/file_utils.py                                                +92   -61
transformers/hf_api.py                                                    +28   -39
transformers/convert_bert_original_tf_checkpoint_to_pytorch.py

@@ -14,18 +14,19 @@
 # limitations under the License.
 """Convert BERT checkpoint."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

 import argparse
+import logging

 import torch

 from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

-import logging

 logging.basicConfig(level=logging.INFO)


 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
     config = BertConfig.from_json_file(bert_config_file)

@@ -42,24 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True,
-                        help="Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--bert_config_file", default=None, type=str, required=True,
-                        help="The config json file corresponding to the pre-trained BERT model.\n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
+    # Required parameters
+    parser.add_argument(
+        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
+    )
+    parser.add_argument(
+        "--bert_config_file",
+        default=None,
+        type=str,
+        required=True,
+        help="The config json file corresponding to the pre-trained BERT model.\n"
+        "This specifies the model architecture.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
     args = parser.parse_args()
-    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
-                                     args.bert_config_file,
-                                     args.pytorch_dump_path)
+    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
transformers/convert_bert_pytorch_checkpoint_to_original_tf.py

@@ -15,15 +15,17 @@
 """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""

-import os
 import argparse
-import torch
+import os
+
 import numpy as np
 import tensorflow as tf
+import torch
+
 from transformers import BertModel


-def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
+def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

     """
     :param model:BertModel Pytorch model instance to be converted

@@ -41,22 +43,17 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
     N BertForQuestionAnswering
     """

-    tensors_to_transpose = (
-        "dense.weight",
-        "attention.self.query",
-        "attention.self.key",
-        "attention.self.value"
-    )
+    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

     var_map = (
-        ('layer.', 'layer_'),
-        ('word_embeddings.weight', 'word_embeddings'),
-        ('position_embeddings.weight', 'position_embeddings'),
-        ('token_type_embeddings.weight', 'token_type_embeddings'),
-        ('.', '/'),
-        ('LayerNorm/weight', 'LayerNorm/gamma'),
-        ('LayerNorm/bias', 'LayerNorm/beta'),
-        ('weight', 'kernel')
+        ("layer.", "layer_"),
+        ("word_embeddings.weight", "word_embeddings"),
+        ("position_embeddings.weight", "position_embeddings"),
+        ("token_type_embeddings.weight", "token_type_embeddings"),
+        (".", "/"),
+        ("LayerNorm/weight", "LayerNorm/gamma"),
+        ("LayerNorm/bias", "LayerNorm/beta"),
+        ("weight", "kernel"),
     )

     if not os.path.isdir(ckpt_dir):

@@ -64,12 +61,12 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
     state_dict = model.state_dict()

-    def to_tf_var_name(name:str):
+    def to_tf_var_name(name: str):
         for patt, repl in iter(var_map):
             name = name.replace(patt, repl)
-        return 'bert/{}'.format(name)
+        return "bert/{}".format(name)

-    def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
+    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
         tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
         tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
         session.run(tf.variables_initializer([tf_var]))

@@ -94,36 +91,21 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
 def main(raw_args=None):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name",
-                        type=str,
-                        required=True,
-                        help="model name e.g. bert-base-uncased")
-    parser.add_argument("--cache_dir",
-                        type=str,
-                        default=None,
-                        required=False,
-                        help="Directory containing pytorch model")
-    parser.add_argument("--pytorch_model_path",
-                        type=str,
-                        required=True,
-                        help="/path/to/<pytorch-model-name>.bin")
-    parser.add_argument("--tf_cache_dir",
-                        type=str,
-                        required=True,
-                        help="Directory in which to save tensorflow model")
+    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
+    parser.add_argument("--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model")
+    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
+    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
     args = parser.parse_args(raw_args)

     model = BertModel.from_pretrained(
         pretrained_model_name_or_path=args.model_name,
         state_dict=torch.load(args.pytorch_model_path),
-        cache_dir=args.cache_dir
+        cache_dir=args.cache_dir,
     )

-    convert_pytorch_checkpoint_to_tf(
-        model=model,
-        ckpt_dir=args.tf_cache_dir,
-        model_name=args.model_name
-    )
+    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)


 if __name__ == "__main__":
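A note on the file above: `to_tf_var_name` just applies the `var_map` substitutions in order and prefixes `bert/`. A runnable walk-through of what that rewriting produces (the parameter name below is a hypothetical example, not taken from the diff):

# Illustrative walk-through of to_tf_var_name from the diff above.
var_map = (
    ("layer.", "layer_"),
    ("word_embeddings.weight", "word_embeddings"),
    ("position_embeddings.weight", "position_embeddings"),
    ("token_type_embeddings.weight", "token_type_embeddings"),
    (".", "/"),
    ("LayerNorm/weight", "LayerNorm/gamma"),
    ("LayerNorm/bias", "LayerNorm/beta"),
    ("weight", "kernel"),
)

name = "encoder.layer.0.attention.output.dense.weight"
for patt, repl in var_map:
    name = name.replace(patt, repl)
print("bert/{}".format(name))
# -> bert/encoder/layer_0/attention/output/dense/kernel

Since this tensor name contains "dense.weight", it is also listed in `tensors_to_transpose`, so it would additionally be transposed before being written to the TF checkpoint.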
transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py

@@ -17,16 +17,14 @@
 from __future__ import absolute_import, division, print_function

 import argparse
+import logging
-from io import open

 import torch

-from transformers import (CONFIG_NAME, WEIGHTS_NAME,
-                          GPT2Config,
-                          GPT2Model,
-                          load_tf_weights_in_gpt2)
+from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2

-import logging

 logging.basicConfig(level=logging.INFO)

@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
     load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

     # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(model.state_dict(), pytorch_weights_dump_path)
     print("Save configuration file to {}".format(pytorch_config_dump_path))

@@ -53,23 +51,19 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--gpt2_checkpoint_path", default=None, type=str, required=True,
-                        help="Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
-    parser.add_argument("--gpt2_config_file", default="", type=str,
-                        help="An optional config json file corresponding to the pre-trained OpenAI model.\n"
-                             "This specifies the model architecture.")
+    # Required parameters
+    parser.add_argument(
+        "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
+    parser.add_argument(
+        "--gpt2_config_file",
+        default="",
+        type=str,
+        help="An optional config json file corresponding to the pre-trained OpenAI model.\n"
+        "This specifies the model architecture.",
+    )
     args = parser.parse_args()
-    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
-                                       args.gpt2_config_file,
-                                       args.pytorch_dump_folder_path)
+    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
transformers/convert_openai_original_tf_checkpoint_to_pytorch.py

@@ -17,16 +17,14 @@
 from __future__ import absolute_import, division, print_function

 import argparse
+import logging
-from io import open

 import torch

-from transformers import (CONFIG_NAME, WEIGHTS_NAME,
-                          OpenAIGPTConfig,
-                          OpenAIGPTModel,
-                          load_tf_weights_in_openai_gpt)
+from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt

-import logging

 logging.basicConfig(level=logging.INFO)

@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
     load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

     # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(model.state_dict(), pytorch_weights_dump_path)
     print("Save configuration file to {}".format(pytorch_config_dump_path))

@@ -53,23 +51,25 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--openai_checkpoint_folder_path", default=None, type=str, required=True,
-                        help="Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
-    parser.add_argument("--openai_config_file", default="", type=str,
-                        help="An optional config json file corresponding to the pre-trained OpenAI model.\n"
-                             "This specifies the model architecture.")
+    # Required parameters
+    parser.add_argument(
+        "--openai_checkpoint_folder_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the TensorFlow checkpoint path.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
+    parser.add_argument(
+        "--openai_config_file",
+        default="",
+        type=str,
+        help="An optional config json file corresponding to the pre-trained OpenAI model.\n"
+        "This specifies the model architecture.",
+    )
     args = parser.parse_args()
-    convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
-                                         args.openai_config_file,
-                                         args.pytorch_dump_folder_path)
+    convert_openai_checkpoint_to_pytorch(
+        args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
+    )
transformers/convert_pytorch_checkpoint_to_tf2.py

@@ -14,92 +14,276 @@
 # limitations under the License.
 """ Convert pytorch checkpoints to TensorFlow """

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

-import os
 import argparse
+import logging
+import os

 import tensorflow as tf

-from transformers import is_torch_available, cached_path
-from transformers import (load_pytorch_checkpoint_in_tf2_model,
-                          BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering,
-                          TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification,
-                          ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering,
-                          TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
+from transformers import (
+    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    AlbertConfig,
+    BertConfig,
+    CTRLConfig,
+    DistilBertConfig,
+    GPT2Config,
+    OpenAIGPTConfig,
+    RobertaConfig,
+    T5Config,
+    TFAlbertForMaskedLM,
+    TFBertForPreTraining,
+    TFBertForQuestionAnswering,
+    TFBertForSequenceClassification,
+    TFCTRLLMHeadModel,
+    TFDistilBertForMaskedLM,
+    TFDistilBertForQuestionAnswering,
+    TFGPT2LMHeadModel,
+    TFOpenAIGPTLMHeadModel,
+    TFRobertaForMaskedLM,
+    TFRobertaForSequenceClassification,
+    TFT5WithLMHeadModel,
+    TFTransfoXLLMHeadModel,
+    TFXLMWithLMHeadModel,
+    TFXLNetLMHeadModel,
+    TransfoXLConfig,
+    XLMConfig,
+    XLNetConfig,
+    cached_path,
+    is_torch_available,
+    load_pytorch_checkpoint_in_tf2_model,
+)

 if is_torch_available():
     import torch
     import numpy as np
-    from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
-                              BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification,
-                              DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                              T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
+    from transformers import (
+        BertForPreTraining,
+        BertForQuestionAnswering,
+        BertForSequenceClassification,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        GPT2LMHeadModel,
+        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLNetLMHeadModel,
+        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLMWithLMHeadModel,
+        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TransfoXLLMHeadModel,
+        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        OpenAIGPTLMHeadModel,
+        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        RobertaForMaskedLM,
+        RobertaForSequenceClassification,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        DistilBertForMaskedLM,
+        DistilBertForQuestionAnswering,
+        DistilBertForSequenceClassification,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        CTRLLMHeadModel,
+        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        AlbertForMaskedLM,
+        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        T5WithLMHeadModel,
+        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+    )
 else:
-    (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-     GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
-     XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
-     XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
-     TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
-     OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-     RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-     DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
-     DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-     CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
-     AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-     T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
-        None, None, None, None,
-        None, None, None, None,
-        None, None, None, None,
-        None, None, None, None,
-        None, None, None, None,
-        None, None, None, None,
-        None, None, None)
+    (
+        BertForPreTraining,
+        BertForQuestionAnswering,
+        BertForSequenceClassification,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        GPT2LMHeadModel,
+        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLNetLMHeadModel,
+        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLMWithLMHeadModel,
+        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TransfoXLLMHeadModel,
+        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        OpenAIGPTLMHeadModel,
+        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        RobertaForMaskedLM,
+        RobertaForSequenceClassification,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        DistilBertForMaskedLM,
+        DistilBertForSequenceClassification,
+        DistilBertForQuestionAnswering,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        CTRLLMHeadModel,
+        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        AlbertForMaskedLM,
+        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        T5WithLMHeadModel,
+        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+    ) = (
+        None, None, None, None, None, None, None, None, None,
+        None, None, None, None, None, None, None, None, None,
+        None, None, None, None, None, None, None, None, None,
+    )

-import logging

 logging.basicConfig(level=logging.INFO)
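A note on the guard above: `is_torch_available()` imports the PyTorch classes only when torch is installed, and binds every name to `None` otherwise so that module-level references still resolve. A minimal sketch of the same optional-dependency pattern, with placeholder module and class names (nothing below is from the diff):

import importlib.util

# "heavylib" is a hypothetical optional dependency.
if importlib.util.find_spec("heavylib") is not None:
    from heavylib import HeavyModel
else:
    HeavyModel = None  # referenced only when the dependency is installed

def build_model():
    if HeavyModel is None:
        raise RuntimeError("heavylib is required for build_model()")
    return HeavyModel()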
 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    "bert": (
+        BertConfig, TFBertForPreTraining, BertForPreTraining,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "bert-large-uncased-whole-word-masking-finetuned-squad": (
+        BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "bert-large-cased-whole-word-masking-finetuned-squad": (
+        BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "bert-base-cased-finetuned-mrpc": (
+        BertConfig, TFBertForSequenceClassification, BertForSequenceClassification,
+        BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "gpt2": (
+        GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel,
+        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "xlnet": (
+        XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel,
+        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "xlm": (
+        XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel,
+        XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "transfo-xl": (
+        TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel,
+        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "openai-gpt": (
+        OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel,
+        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "roberta": (
+        RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "roberta-large-mnli": (
+        RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "distilbert": (
+        DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "distilbert-base-uncased-distilled-squad": (
+        DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "distilbert-base-uncased-distilled-squad": (
+        DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "ctrl": (
+        CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel,
+        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "albert": (
+        AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM,
+        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
+    "t5": (
+        T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel,
+        T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
+    ),
 }


-def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
+def convert_pt_checkpoint_to_tf(
+    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
+):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

@@ -116,17 +300,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     # Load weights from tf checkpoint
     if pytorch_checkpoint_path in aws_model_maps:
-        pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
+        pytorch_checkpoint_path = cached_path(
+            aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
+        )

     # Load PyTorch checkpoint in tf2 model:
     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

-        state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
-        pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
-                                                  config=config,
-                                                  state_dict=state_dict)
+        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
+        pt_model = pt_model_class.from_pretrained(
+            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+        )

         with torch.no_grad():
             pto = pt_model(**pt_model.dummy_inputs)

@@ -139,11 +325,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
-    tf_model.save_weights(tf_dump_path, save_format='h5')
+    tf_model.save_weights(tf_dump_path, save_format="h5")


-def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None,
-                                     config_shortcut_names_or_path=None, compare_with_pt_model=False,
-                                     use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
+def convert_all_pt_checkpoints_to_tf(
+    args_model_type,
+    tf_dump_path,
+    model_shortcut_names_or_path=None,
+    config_shortcut_names_or_path=None,
+    compare_with_pt_model=False,
+    use_cached_models=False,
+    remove_cached_files=False,
+    only_convert_finetuned_models=False,
+):
     assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"

     if args_model_type is None:

@@ -156,7 +350,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
         print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
         print("=" * 100)
         if model_type not in MODEL_CLASSES:
-            raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
+            raise ValueError(
+                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
+            )
         config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

@@ -166,9 +362,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
             config_shortcut_names_or_path = model_shortcut_names_or_path

         for i, (model_shortcut_name, config_shortcut_name) in enumerate(
-                zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
+            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
+        ):
             print("-" * 100)
-            if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
+            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                 if not only_convert_finetuned_models:
                     print("    Skipping finetuned checkpoint {}".format(model_shortcut_name))
                     continue

@@ -176,7 +373,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
             elif only_convert_finetuned_models:
                 print("    Skipping not finetuned checkpoint {}".format(model_shortcut_name))
                 continue
-            print("    Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
+            print(
+                "    Converting checkpoint {}/{}: {} - model_type {}".format(
+                    i, len(aws_config_map), model_shortcut_name, model_type
+                )
+            )
             print("-" * 100)

             if config_shortcut_name in aws_config_map:

@@ -190,13 +391,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
                 model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)

             if os.path.isfile(model_shortcut_name):
-                model_shortcut_name = 'converted_model'
+                model_shortcut_name = "converted_model"

-            convert_pt_checkpoint_to_tf(model_type=model_type,
-                                        pytorch_checkpoint_path=model_file,
-                                        config_file=config_file,
-                                        tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
-                                        compare_with_pt_model=compare_with_pt_model)
+            convert_pt_checkpoint_to_tf(
+                model_type=model_type,
+                pytorch_checkpoint_path=model_file,
+                config_file=config_file,
+                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
+                compare_with_pt_model=compare_with_pt_model,
+            )
             if remove_cached_files:
                 os.remove(config_file)
                 os.remove(model_file)

@@ -204,40 +407,48 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--tf_dump_path", default=None, type=str, required=True,
-                        help="Path to the output Tensorflow dump file.")
-    parser.add_argument("--model_type", default=None, type=str,
-                        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
-    parser.add_argument("--pytorch_checkpoint_path", default=None, type=str,
-                        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
-                             "If not given, will download and convert all the checkpoints from AWS.")
-    parser.add_argument("--config_file", default=None, type=str,
-                        help="The config json file corresponding to the pre-trained model.\n"
-                             "This specifies the model architecture. If not given and "
-                             "--pytorch_checkpoint_path is not given or is a shortcut name"
-                             "use the configuration associated to the shortcut name on the AWS")
-    parser.add_argument("--compare_with_pt_model", action='store_true',
-                        help="Compare Tensorflow and PyTorch model predictions.")
-    parser.add_argument("--use_cached_models", action='store_true',
-                        help="Use cached models if possible instead of updating to latest checkpoint versions.")
-    parser.add_argument("--remove_cached_files", action='store_true',
-                        help="Remove pytorch models after conversion (save memory when converting in batches).")
-    parser.add_argument("--only_convert_finetuned_models", action='store_true',
-                        help="Only convert finetuned models.")
+    # Required parameters
+    parser.add_argument(
+        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
+    )
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
+            list(MODEL_CLASSES.keys())
+        ),
+    )
+    parser.add_argument(
+        "--pytorch_checkpoint_path",
+        default=None,
+        type=str,
+        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
+        "If not given, will download and convert all the checkpoints from AWS.",
+    )
+    parser.add_argument(
+        "--config_file",
+        default=None,
+        type=str,
+        help="The config json file corresponding to the pre-trained model.\n"
+        "This specifies the model architecture. If not given and "
+        "--pytorch_checkpoint_path is not given or is a shortcut name"
+        "use the configuration associated to the shortcut name on the AWS",
+    )
+    parser.add_argument(
+        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
+    )
+    parser.add_argument(
+        "--use_cached_models",
+        action="store_true",
+        help="Use cached models if possible instead of updating to latest checkpoint versions.",
+    )
+    parser.add_argument(
+        "--remove_cached_files",
+        action="store_true",
+        help="Remove pytorch models after conversion (save memory when converting in batches).",
+    )
+    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
     args = parser.parse_args()

     # if args.pytorch_checkpoint_path is not None:

@@ -248,11 +459,15 @@ if __name__ == "__main__":
     #                                 compare_with_pt_model=args.compare_with_pt_model,
     #                                 use_cached_models=args.use_cached_models)
     # else:
-    convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
-                                     args.tf_dump_path,
-                                     model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
-                                     config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
-                                     compare_with_pt_model=args.compare_with_pt_model,
-                                     use_cached_models=args.use_cached_models,
-                                     remove_cached_files=args.remove_cached_files,
-                                     only_convert_finetuned_models=args.only_convert_finetuned_models)
+    convert_all_pt_checkpoints_to_tf(
+        args.model_type.lower() if args.model_type is not None else None,
+        args.tf_dump_path,
+        model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
+        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
+        compare_with_pt_model=args.compare_with_pt_model,
+        use_cached_models=args.use_cached_models,
+        remove_cached_files=args.remove_cached_files,
+        only_convert_finetuned_models=args.only_convert_finetuned_models,
+    )
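A note on the structure above: each `MODEL_CLASSES` entry bundles a (config class, TF model class, PyTorch model class, weights map, config map) 5-tuple, so the converter can dispatch on a model-type string and unpack whatever it needs. A self-contained sketch of that lookup pattern (REGISTRY and its entry are stand-ins, not the real MODEL_CLASSES):

# Sketch of the dispatch pattern behind convert_pt_checkpoint_to_tf above.
REGISTRY = {
    "demo": ("DemoConfig", "TFDemoModel", "DemoModel", {}, {}),
}

def lookup(model_type):
    if model_type not in REGISTRY:
        raise ValueError("Unrecognized model type, should be one of {}.".format(list(REGISTRY.keys())))
    config_class, tf_model_class, pt_model_class, aws_model_maps, aws_config_map = REGISTRY[model_type]
    return config_class, tf_model_class, pt_model_class

print(lookup("demo"))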
transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py

@@ -18,32 +18,33 @@ from __future__ import absolute_import, division, print_function

 import argparse
 import logging
+
 import numpy as np
-import torch
-
-import pathlib
-
 import fairseq
+import torch
+from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
+from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version

+from transformers.modeling_bert import (
+    BertConfig,
+    BertIntermediate,
+    BertLayer,
+    BertOutput,
+    BertSelfAttention,
+    BertSelfOutput,
+)
+from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification
+
 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
     raise Exception("requires fairseq >= 0.9.0")

-from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
-from fairseq.modules import TransformerSentenceEncoderLayer
-from transformers.modeling_bert import (BertConfig, BertEncoder, BertIntermediate, BertLayer,
-                                        BertModel, BertOutput, BertSelfAttention, BertSelfOutput)
-from transformers.modeling_roberta import (RobertaEmbeddings, RobertaForMaskedLM,
-                                           RobertaForSequenceClassification, RobertaModel)

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-SAMPLE_TEXT = 'Hello world! cécé herlolip'
+SAMPLE_TEXT = "Hello world! cécé herlolip"


 def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):

@@ -74,7 +75,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     # Embeddings
     model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
     model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
-    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight)  # just zero them out b/c RoBERTa doesn't use them.
+    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
+        model.roberta.embeddings.token_type_embeddings.weight
+    )  # just zero them out b/c RoBERTa doesn't use them.
     model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
     model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias

@@ -83,13 +86,13 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         layer: BertLayer = model.roberta.encoder.layer[i]
         roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

-        ### self attention
+        # self attention
         self_attn: BertSelfAttention = layer.attention.self
-        assert (roberta_layer.self_attn.k_proj.weight.data.shape == \
-                roberta_layer.self_attn.q_proj.weight.data.shape == \
-                roberta_layer.self_attn.v_proj.weight.data.shape == \
-                torch.Size((config.hidden_size, config.hidden_size)))
+        assert (
+            roberta_layer.self_attn.k_proj.weight.data.shape
+            == roberta_layer.self_attn.q_proj.weight.data.shape
+            == roberta_layer.self_attn.v_proj.weight.data.shape
+            == torch.Size((config.hidden_size, config.hidden_size))
+        )

         self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight

@@ -99,40 +102,34 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
         self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias

-        ### self-attention output
+        # self-attention output
         self_output: BertSelfOutput = layer.attention.output
-        assert(
-            self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
-        )
+        assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
         self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
         self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
         self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
         self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias

-        ### intermediate
+        # intermediate
         intermediate: BertIntermediate = layer.intermediate
-        assert(
-            intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
-        )
+        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
         intermediate.dense.weight = roberta_layer.fc1.weight
         intermediate.dense.bias = roberta_layer.fc1.bias

-        ### output
+        # output
         bert_output: BertOutput = layer.output
-        assert(
-            bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
-        )
+        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
         bert_output.dense.weight = roberta_layer.fc2.weight
         bert_output.dense.bias = roberta_layer.fc2.bias
         bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
         bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
-        #### end of layer
+        # end of layer

     if classification_head:
-        model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
-        model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
-        model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
-        model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
+        model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
+        model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
+        model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
+        model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
     else:
         # LM Head
         model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight

@@ -147,17 +144,14 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     our_output = model(input_ids)[0]
     if classification_head:
-        their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
+        their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
     else:
         their_output = roberta.model(input_ids)[0]
     print(our_output.shape, their_output.shape)
     max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
     print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
     success = torch.allclose(our_output, their_output, atol=1e-3)
-    print(
-        "Do both models output the same tensors?",
-        "🔥" if success else "💩"
-    )
+    print("Do both models output the same tensors?", "🔥" if success else "💩")
     if not success:
         raise Exception("Something went wRoNg")

@@ -168,24 +162,17 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--roberta_checkpoint_path", default=None, type=str, required=True,
-                        help="Path the official PyTorch dump.")
-    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
-    parser.add_argument("--classification_head", action="store_true",
-                        help="Whether to convert a final classification head.")
+    # Required parameters
+    parser.add_argument(
+        "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
+    parser.add_argument(
+        "--classification_head", action="store_true", help="Whether to convert a final classification head."
+    )
     args = parser.parse_args()
     convert_roberta_checkpoint_to_pytorch(
-        args.roberta_checkpoint_path,
-        args.pytorch_dump_folder_path,
-        args.classification_head
+        args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
     )
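The converter above validates the port numerically: it feeds the same input through the Fairseq model and the converted model, prints the maximum absolute difference, and fails unless `torch.allclose(our_output, their_output, atol=1e-3)` holds. A self-contained sketch of that idiom with made-up tensors (requires torch):

import torch

# Stand-ins for the converted model's output and the reference output.
our_output = torch.tensor([1.0000, 2.0000, 3.0000])
their_output = torch.tensor([1.0002, 1.9999, 3.0001])

max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}")  # ~2e-4 here; ~1e-7 for a faithful port
if not torch.allclose(our_output, their_output, atol=1e-3):
    raise Exception("Something went wRoNg")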
transformers/convert_t5_original_tf_checkpoint_to_pytorch.py

@@ -14,18 +14,19 @@
 # limitations under the License.
 """Convert T5 checkpoint."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function

 import argparse
+import logging

 import torch

 from transformers import T5Config, T5Model, load_tf_weights_in_t5

-import logging

 logging.basicConfig(level=logging.INFO)


 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
     # Initialise PyTorch model
     config = T5Config.from_json_file(config_file)

@@ -42,24 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True,
-                        help="Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--config_file", default=None, type=str, required=True,
-                        help="The config json file corresponding to the pre-trained T5 model.\n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
+    # Required parameters
+    parser.add_argument(
+        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
+    )
+    parser.add_argument(
+        "--config_file",
+        default=None,
+        type=str,
+        required=True,
+        help="The config json file corresponding to the pre-trained T5 model.\n"
+        "This specifies the model architecture.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
     args = parser.parse_args()
-    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
-                                     args.config_file,
-                                     args.pytorch_dump_path)
+    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py

@@ -17,6 +17,7 @@
 from __future__ import absolute_import, division, print_function

 import argparse
+import logging
 import os
 import sys
 from io import open

@@ -24,44 +25,48 @@ from io import open
 import torch

 import transformers.tokenization_transfo_xl as data_utils
+from transformers import (
+    CONFIG_NAME,
+    WEIGHTS_NAME,
+    TransfoXLConfig,
+    TransfoXLLMHeadModel,
+    load_tf_weights_in_transfo_xl,
+)
+from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES

-from transformers import CONFIG_NAME, WEIGHTS_NAME
-from transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
-                          load_tf_weights_in_transfo_xl)
-from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)

 if sys.version_info[0] == 2:
     import cPickle as pickle
 else:
     import pickle

-import logging

 logging.basicConfig(level=logging.INFO)

 # We do this to be able to load python 2 datasets pickles
 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
 data_utils.Vocab = data_utils.TransfoXLTokenizer
 data_utils.Corpus = data_utils.TransfoXLCorpus
-sys.modules['data_utils'] = data_utils
-sys.modules['vocabulary'] = data_utils
+sys.modules["data_utils"] = data_utils
+sys.modules["vocabulary"] = data_utils


-def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
-                                             transfo_xl_config_file,
-                                             pytorch_dump_folder_path,
-                                             transfo_xl_dataset_file):
+def convert_transfo_xl_checkpoint_to_pytorch(
+    tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
+):
     if transfo_xl_dataset_file:
         # Convert a pre-processed corpus (see original TensorFlow repo)
         with open(transfo_xl_dataset_file, "rb") as fp:
             corpus = pickle.load(fp, encoding="latin1")

         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
-        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
         print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
         corpus_vocab_dict = corpus.vocab.__dict__
         torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

         corpus_dict_no_vocab = corpus.__dict__
-        corpus_dict_no_vocab.pop('vocab', None)
-        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
+        corpus_dict_no_vocab.pop("vocab", None)
+        pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
         print("Save dataset to {}".format(pytorch_dataset_dump_path))
         torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

@@ -92,26 +97,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True,
-                        help="Path to the folder to store the PyTorch model or dataset/vocab.")
-    parser.add_argument("--tf_checkpoint_path", default="", type=str,
-                        help="An optional path to a TensorFlow checkpoint path to be converted.")
-    parser.add_argument("--transfo_xl_config_file", default="", type=str,
-                        help="An optional config json file corresponding to the pre-trained BERT model.\n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--transfo_xl_dataset_file", default="", type=str,
-                        help="An optional dataset file to be converted in a vocabulary.")
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the folder to store the PyTorch model or dataset/vocab.",
+    )
+    parser.add_argument(
+        "--tf_checkpoint_path",
+        default="",
+        type=str,
+        help="An optional path to a TensorFlow checkpoint path to be converted.",
+    )
+    parser.add_argument(
+        "--transfo_xl_config_file",
+        default="",
+        type=str,
+        help="An optional config json file corresponding to the pre-trained BERT model.\n"
+        "This specifies the model architecture.",
+    )
+    parser.add_argument(
+        "--transfo_xl_dataset_file",
+        default="",
+        type=str,
+        help="An optional dataset file to be converted in a vocabulary.",
+    )
     args = parser.parse_args()
-    convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
-                                             args.transfo_xl_config_file,
-                                             args.pytorch_dump_folder_path,
-                                             args.transfo_xl_dataset_file)
+    convert_transfo_xl_checkpoint_to_pytorch(
+        args.tf_checkpoint_path,
+        args.transfo_xl_config_file,
+        args.pytorch_dump_folder_path,
+        args.transfo_xl_dataset_file,
+    )
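A note on the `sys.modules` assignments in the file above: pickle stores each object's class as a module-qualified path, so a corpus pickled when the tokenizer lived in a module named `data_utils` can only be unpickled if a module by that name is importable. Aliasing the current module under the old names works around that. A minimal sketch of the trick (the module and file names below are hypothetical):

import pickle
import sys

import current_module  # hypothetical: where the pickled classes live today

# Old pickles reference classes as "legacy_module.SomeClass"; alias the old
# module name so unpickling resolves them against the new module.
sys.modules["legacy_module"] = current_module

with open("legacy.pkl", "rb") as fp:  # hypothetical file
    obj = pickle.load(fp, encoding="latin1")  # encoding="latin1" also handles Python 2 pickles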
transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py

@@ -18,41 +18,43 @@ from __future__ import absolute_import, division, print_function

 import argparse
 import json
+import logging
 from io import open

-import torch
 import numpy
+import torch

 from transformers import CONFIG_NAME, WEIGHTS_NAME
 from transformers.tokenization_xlm import VOCAB_FILES_NAMES

-import logging

 logging.basicConfig(level=logging.INFO)


 def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
     # Load checkpoint
-    chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
+    chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")

-    state_dict = chkpt['model']
+    state_dict = chkpt["model"]

     # We have the base model one level deeper than the original XLM repository
     two_levels_state_dict = {}
     for k, v in state_dict.items():
-        if 'pred_layer' in k:
+        if "pred_layer" in k:
             two_levels_state_dict[k] = v
         else:
-            two_levels_state_dict['transformer.' + k] = v
+            two_levels_state_dict["transformer." + k] = v

-    config = chkpt['params']
+    config = chkpt["params"]
     config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))

-    vocab = chkpt['dico_word2id']
-    vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
+    vocab = chkpt["dico_word2id"]
+    vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())

     # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
+    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+    pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]

     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(two_levels_state_dict, pytorch_weights_dump_path)

@@ -68,16 +70,12 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--xlm_checkpoint_path", default=None, type=str, required=True,
-                        help="Path the official PyTorch dump.")
-    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True,
-                        help="Path to the output PyTorch model.")
+    # Required parameters
+    parser.add_argument(
+        "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
     args = parser.parse_args()
     convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
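The `dico_word2id` rewrite above normalizes XLM's BPE vocabulary: pieces marked with "@@" are continuations and lose the marker, while unmarked pieces past the first 14 special entries are word-final and gain a "</w>" suffix. A tiny runnable illustration with hypothetical entries:

# Hypothetical entries in the style of chkpt["dico_word2id"].
vocab = {"hel@@": 100, "lo": 101}
converted = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
print(converted)  # {'hel': 100, 'lo</w>': 101}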
transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py View file @ 54abc67a
...
...
@@ -14,19 +14,24 @@
# limitations under the License.
"""Convert BERT checkpoint."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import os
import argparse
import logging
import os

import torch

from transformers import (CONFIG_NAME, WEIGHTS_NAME,
from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    XLNetConfig,
    XLNetLMHeadModel, XLNetForQuestionAnswering,
    XLNetForQuestionAnswering,
    XLNetForSequenceClassification,
    load_tf_weights_in_xlnet)
    XLNetLMHeadModel,
    load_tf_weights_in_xlnet,
)

GLUE_TASKS_NUM_LABELS = {
    "cola": 2,
...
...
@@ -40,10 +45,13 @@ GLUE_TASKS_NUM_LABELS = {
    "wnli": 2,
}

import logging

logging.basicConfig(level=logging.INFO)


def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
def convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
):
    # Initialise PyTorch model
    config = XLNetConfig.from_json_file(bert_config_file)
...
...
@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
        config.finetuning_task = finetuning_task
        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
        model = XLNetForSequenceClassification(config)
    elif 'squad' in finetuning_task:
    elif "squad" in finetuning_task:
        config.finetuning_task = finetuning_task
        model = XLNetForQuestionAnswering(config)
    else:
...
...
@@ -74,31 +82,34 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--xlnet_config_file", default=None, type=str, required=True,
                        help="The config json file corresponding to the pre-trained XLNet model.\n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the folder to store the PyTorch model or dataset/vocab.")
    parser.add_argument("--finetuning_task", default=None, type=str, help="Name of a task on which the XLNet TensorFloaw model was fine-tuned")
    # Required parameters
    parser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path.")
    parser.add_argument(
        "--xlnet_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained XLNet model.\n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the folder to store the PyTorch model or dataset/vocab.",
    )
    parser.add_argument(
        "--finetuning_task",
        default=None,
        type=str,
        help="Name of a task on which the XLNet TensorFloaw model was fine-tuned",
    )
    args = parser.parse_args()
    print(args)

    convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task)
    convert_xlnet_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
    )
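
As the hunks above show, finetuning_task steers which head the script builds: a GLUE task name selects XLNetForSequenceClassification with GLUE_TASKS_NUM_LABELS[finetuning_task] labels, while any task containing "squad" selects XLNetForQuestionAnswering. A hedged sketch of a direct call (all three paths are placeholders):

# Sketch only: checkpoint/config/output paths are illustrative.
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch

# "cola" appears in GLUE_TASKS_NUM_LABELS above (2 labels), so this builds an
# XLNetForSequenceClassification head before loading the TF weights.
convert_xlnet_checkpoint_to_pytorch(
    "/tmp/xlnet/xlnet_model.ckpt",
    "/tmp/xlnet/xlnet_config.json",
    "/tmp/xlnet-converted",
    finetuning_task="cola",
)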
transformers/data/__init__.py View file @ 54abc67a
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from .metrics import is_sklearn_available
from .processors import (
    DataProcessor,
    InputExample,
    InputFeatures,
    SingleSentenceClassificationProcessor,
    SquadExample,
    SquadFeatures,
    SquadV1Processor,
    SquadV2Processor,
    glue_convert_examples_to_features,
    glue_output_modes,
    glue_processors,
    glue_tasks_num_labels,
    squad_convert_examples_to_features,
    xnli_output_modes,
    xnli_processors,
    xnli_tasks_num_labels,
)


if is_sklearn_available():
    from .metrics import glue_compute_metrics, xnli_compute_metrics
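
The reworked module keeps scikit-learn optional: the metric functions are only re-exported when is_sklearn_available() is true. A sketch of consumer code that mirrors the same guard:

# Sketch: mirror the optional-dependency guard used in transformers.data.
from transformers.data import is_sklearn_available

if is_sklearn_available():
    from transformers.data import glue_compute_metrics, xnli_compute_metrics
else:
    glue_compute_metrics = xnli_compute_metrics = None  # metrics need scikit-learn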
transformers/data/metrics/__init__.py View file @ 54abc67a
...
...
@@ -14,29 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys
import logging

logger = logging.getLogger(__name__)

try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score

    _has_sklearn = True
except (AttributeError, ImportError) as e:
    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
    _has_sklearn = False


def is_sklearn_available():
    return _has_sklearn


if _has_sklearn:

    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
...
...
@@ -46,7 +47,6 @@ if _has_sklearn:
            "acc_and_f1": (acc + f1) / 2,
        }

    def pearson_and_spearman(preds, labels):
        pearson_corr = pearsonr(preds, labels)[0]
        spearman_corr = spearmanr(preds, labels)[0]
...
...
@@ -56,7 +56,6 @@ if _has_sklearn:
            "corr": (pearson_corr + spearman_corr) / 2,
        }

    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
...
...
@@ -82,7 +81,6 @@ if _has_sklearn:
        else:
            raise KeyError(task_name)

    def xnli_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "xnli":
...
...
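
A small usage sketch of the metric helpers defined above (assumes scikit-learn is installed; the arrays are made-up predictions, and the exact per-task metric dict is whatever the branch for that task returns):

import numpy as np
from transformers.data.metrics import glue_compute_metrics, is_sklearn_available

if is_sklearn_available():
    preds = np.array([1, 0, 1, 1])
    labels = np.array([1, 0, 0, 1])
    # Dispatches on the task name; "cola" is one of the branches shown above.
    print(glue_compute_metrics("cola", preds, labels))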
transformers/data/metrics/squad_metrics.py View file @ 54abc67a
...
...
@@ -8,35 +8,37 @@ that a question is unanswerable.
"""

import collections
import json
import logging
import math
import collections
from io import open
from tqdm import tqdm
import string
import re
import string
from io import open

from transformers.tokenization_bert import BasicTokenizer
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return ' '.join(text.split())
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
...
...
@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):
    for example in examples:
        qas_id = example.qas_id
        gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
        gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]

        if not gold_answers:
            # For unanswerable questions, only correct answer is empty string
            gold_answers = ['']
            gold_answers = [""]

        if qas_id not in preds:
            print('Missing prediction for %s' % qas_id)
            print("Missing prediction for %s" % qas_id)
            continue

        prediction = preds[qas_id]
...
...
@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores.values()) / total),
                ("f1", 100.0 * sum(f1_scores.values()) / total),
                ("total", total),
            ]
        )
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
            ('total', total),
        ])
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
                ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
                ("total", total),
            ]
        )


def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]
        main_eval["%s_%s" % (prefix, k)] = new_eval[k]


def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
...
...
@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw,
                                                                  na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw,
                                                         na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
    main_eval['has_ans_exact'] = has_ans_exact
    main_eval['has_ans_f1'] = has_ans_f1
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh
    main_eval["has_ans_exact"] = has_ans_exact
    main_eval["has_ans_f1"] = has_ans_f1


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
...
...
@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)

    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh


def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
...
...
@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
    exact, f1 = get_raw_scores(examples, preds)

    exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer,
                                             no_answer_probability_threshold)
    exact_threshold = apply_no_ans_threshold(
        exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
    )
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)

    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, 'HasAns')
        merge_eval(evaluation, has_ans_eval, "HasAns")

    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, 'NoAns')
        merge_eval(evaluation, no_ans_eval, "NoAns")

    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
...
...
@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logger.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
            logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1
...
@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
                        orig_ns_text, tok_ns_text)
            logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
...
...
@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
            logger.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text
...
...
@@ -393,8 +397,8 @@ def compute_predictions_logits(
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
...
...
@@ -447,7 +451,9 @@ def compute_predictions_logits(
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
...
...
@@ -455,14 +461,14 @@ def compute_predictions_logits(
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit))
        prelim_predictions = sorted(prelim_predictions,
                                    key=lambda x: (x.start_logit + x.end_logit),
                                    reverse=True)
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
...
...
@@ -471,10 +477,10 @@ def compute_predictions_logits(
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
...
...
@@ -498,31 +504,21 @@ def compute_predictions_logits(
                final_text = ""
            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))
            nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="",
                        start_logit=null_start_logit,
                        end_logit=null_end_logit))
                nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0,
                    _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1
...
...
@@ -551,8 +547,7 @@ def compute_predictions_logits(
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
...
...
@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
    end_n_top,
    version_2_with_negative,
    tokenizer,
    verbose_logging
    verbose_logging,
):
    """ XLNet write prediction logic (more complex than Bert's).
        Write final predictions to the json file and log-odds of null if needed.
...
...
@@ -594,12 +589,12 @@
        Requires utils_squad_evaluate.py
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"])
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
    )

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
    )

    logger.info("Writing predictions to: %s", output_prediction_file)
    # logger.info("Writing nbest to: %s" % (output_nbest_file))
...
...
@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
                        start_index=start_index,
                        end_index=end_index,
                        start_log_prob=start_log_prob,
                        end_log_prob=end_log_prob))
                        end_log_prob=end_log_prob,
                    )
                )

        prelim_predictions = sorted(prelim_predictions,
                                    key=lambda x: (x.start_log_prob + x.end_log_prob),
                                    reverse=True)
        prelim_predictions = sorted(
            prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
        )

        seen_predictions = {}
        nbest = []
...
...
@@ -688,10 +684,10 @@ def compute_predictions_log_probs(
            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
            # Previously used Bert untokenizer
            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

            # Clean whitespace
...
...
@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
            else:
                do_lower_case = tokenizer.do_lowercase_and_remove_accent

            final_text = get_final_text(tok_text, orig_text, do_lower_case,
                                        verbose_logging)
            final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)

            if final_text in seen_predictions:
                continue
...
...
@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_log_prob=pred.start_log_prob,
                    end_log_prob=pred.end_log_prob))
            nbest.append(
                _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
            )

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
            nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))

        total_scores = []
        best_non_null_entry = None
...
...
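
The normalization chain above composes lower, remove_punc, remove_articles and white_space_fix, so both exact-match and F1 compare answers in a canonical form. A tiny sketch:

# Sketch of the canonicalization performed by normalize_answer above.
from transformers.data.metrics.squad_metrics import normalize_answer

assert normalize_answer("The  Eiffel Tower!") == "eiffel tower"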
transformers/data/processors/__init__.py View file @ 54abc67a
from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
transformers/data/processors/glue.py View file @ 54abc67a
...
...
@@ -18,8 +18,9 @@
import logging
import os

from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available
from .utils import DataProcessor, InputExample, InputFeatures

if is_tf_available():
    import tensorflow as tf
...
...
@@ -27,7 +28,9 @@ if is_tf_available():
logger = logging.getLogger(__name__)


def glue_convert_examples_to_features(examples, tokenizer,
def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    task=None,
    label_list=None,
...
...
@@ -35,7 +38,8 @@ def glue_convert_examples_to_features(examples, tokenizer,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True):
    mask_padding_with_zero=True,
):
    """
    Loads a data file into a list of ``InputFeatures``
...
...
@@ -82,12 +86,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
            example = processor.get_example_from_tensor_dict(example)
            example = processor.tfds_map(example)

        inputs = tokenizer.encode_plus(
            example.text_a,
            example.text_b,
            add_special_tokens=True,
            max_length=max_length,
        )
        inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
...
...
@@ -106,8 +105,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        if output_mode == "classification":
            label = label_map[example.label]
...
...
@@ -125,28 +128,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
            logger.info("label: %s (id = %d)" % (example.label, label))

        features.append(
                InputFeatures(input_ids=input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              label=label))
        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
            )
        )

    if is_tf_available() and is_tf_dataset:

        def gen():
            for ex in features:
                yield ({'input_ids': ex.input_ids,
                        'attention_mask': ex.attention_mask,
                        'token_type_ids': ex.token_type_ids},
                       ex.label)

        return tf.data.Dataset.from_generator(gen,
            ({'input_ids': tf.int32, 'attention_mask': tf.int32, 'token_type_ids': tf.int32}, tf.int64),
            ({'input_ids': tf.TensorShape([None]), 'attention_mask': tf.TensorShape([None]), 'token_type_ids': tf.TensorShape([None])}, tf.TensorShape([])))
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        return tf.data.Dataset.from_generator(
            gen,
            ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )

    return features
...
...
@@ -156,21 +167,21 @@ class MrpcProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence1'].numpy().decode('utf-8'),
                            tensor_dict['sentence2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -186,8 +197,7 @@ class MrpcProcessor(DataProcessor):
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -196,21 +206,20 @@ class MnliProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['premise'].numpy().decode('utf-8'),
                            tensor_dict['hypothesis'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_labels(self):
        """See base class."""
...
...
@@ -226,8 +235,7 @@ class MnliProcessor(DataProcessor):
            text_a = line[8]
            text_b = line[9]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -236,9 +244,7 @@ class MnliMismatchedProcessor(MnliProcessor):

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
            "dev_matched")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")


class ColaProcessor(DataProcessor):
...
...
@@ -246,20 +252,20 @@ class ColaProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence'].numpy().decode('utf-8'),
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict['label'].numpy()))
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -272,8 +278,7 @@ class ColaProcessor(DataProcessor):
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            label = line[1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
...
...
@@ -282,20 +287,20 @@ class Sst2Processor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence'].numpy().decode('utf-8'),
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict['label'].numpy()))
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -310,8 +315,7 @@ class Sst2Processor(DataProcessor):
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
...
...
@@ -320,20 +324,20 @@ class StsbProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence1'].numpy().decode('utf-8'),
                            tensor_dict['sentence2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -349,8 +353,7 @@ class StsbProcessor(DataProcessor):
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -359,20 +362,20 @@ class QqpProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['question1'].numpy().decode('utf-8'),
                            tensor_dict['question2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -391,8 +394,7 @@ class QqpProcessor(DataProcessor):
                label = line[5]
            except IndexError:
                continue
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -401,21 +403,20 @@ class QnliProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['question'].numpy().decode('utf-8'),
                            tensor_dict['sentence'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

    def get_labels(self):
        """See base class."""
...
...
@@ -431,8 +432,7 @@ class QnliProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -441,20 +441,20 @@ class RteProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence1'].numpy().decode('utf-8'),
                            tensor_dict['sentence2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -470,8 +470,7 @@ class RteProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
...
...
@@ -480,20 +479,20 @@ class WnliProcessor(DataProcessor):

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(tensor_dict['idx'].numpy(),
                            tensor_dict['sentence1'].numpy().decode('utf-8'),
                            tensor_dict['sentence2'].numpy().decode('utf-8'),
                            str(tensor_dict['label'].numpy()))
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...
...
@@ -509,10 +508,10 @@ class WnliProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
...
...
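
A usage sketch for the featurization helper above. The tokenizer choice and the two sentence pairs are illustrative; passing task="mrpc" pulls the label list and output mode from the processors defined in this file:

from transformers import BertTokenizer
from transformers.data.processors.glue import glue_convert_examples_to_features
from transformers.data.processors.utils import InputExample

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = [
    InputExample(guid="train-1", text_a="A man is eating.", text_b="Someone eats.", label="1"),
    InputExample(guid="train-2", text_a="A dog runs.", text_b="It is raining.", label="0"),
]
features = glue_convert_examples_to_features(examples, tokenizer, max_length=128, task="mrpc")
print(len(features), len(features[0].input_ids))  # 2 examples, each padded to 128 ids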
transformers/data/processors/squad.py View file @ 54abc67a
from tqdm import tqdm
import collections
import json
import logging
import os
import json
import numpy as np
from multiprocessing import Pool
from multiprocessing import cpu_count
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
from tqdm import tqdm

from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from .utils import DataProcessor

if is_torch_available():
    import torch
...
...
@@ -82,8 +82,8 @@ def _is_whitespace(c):
        return True
    return False


def squad_convert_example_to_features(example, max_seq_length,
                                      doc_stride, max_query_length, is_training):
def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
    features = []
    if is_training and not example.is_impossible:
        # Get start and end position
...
...
@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length,
        end_position = example.end_position

        # If the answer cannot be found in the text, then skip this example.
        actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
        actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
        if actual_text.find(cleaned_answer_text) == -1:
            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
...
...
@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
    spans = []

    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
    sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
        if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
    sequence_added_tokens = (
        tokenizer.max_len - tokenizer.max_len_single_sentence + 1
        if "roberta" in str(type(tokenizer))
        else tokenizer.max_len - tokenizer.max_len_single_sentence
    )
    sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

    span_doc_tokens = all_doc_tokens
...
...
@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
            return_overflowing_tokens=True,
            pad_to_max_length=True,
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
            truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
        )

        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
                            max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
        paragraph_len = min(
            len(all_doc_tokens) - len(spans) * doc_stride,
            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
        )

        if tokenizer.pad_token_id in encoded_dict['input_ids']:
            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
            non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
        else:
            non_padded_ids = encoded_dict['input_ids']
            non_padded_ids = encoded_dict["input_ids"]

        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
...
...
@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
    for doc_span_index in range(len(spans)):
        for j in range(spans[doc_span_index]["paragraph_len"]):
            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
            index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            index = (
                j
                if tokenizer.padding_side == "left"
                else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            )
            spans[doc_span_index]["token_is_max_context"][index] = is_max_context

    for span in spans:
        # Identify the position of the CLS token
        cls_index = span['input_ids'].index(tokenizer.cls_token_id)
        cls_index = span["input_ids"].index(tokenizer.cls_token_id)

        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
        # Original TF implem also keep the classification token (set to 0) (not sure why...)
        p_mask = np.array(span['token_type_ids'])
        p_mask = np.array(span["token_type_ids"])

        p_mask = np.minimum(p_mask, 1)
...
...
@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

        features.append(SquadFeatures(
            span['input_ids'],
            span['attention_mask'],
            span['token_type_ids'],
        features.append(
            SquadFeatures(
                span["input_ids"],
                span["attention_mask"],
                span["token_type_ids"],
                cls_index,
                p_mask.tolist(),
                example_index=0,  # Can not set unique_id and example_index here. They will be set after multiple processing.
                unique_id=0,
                paragraph_len=span['paragraph_len'],
                paragraph_len=span["paragraph_len"],
                token_is_max_context=span["token_is_max_context"],
                tokens=span["tokens"],
                token_to_orig_map=span["token_to_orig_map"],
                start_position=start_position,
                end_position=end_position))
                end_position=end_position,
            )
        )
    return features


def squad_convert_example_to_features_init(tokenizer_for_convert):
    global tokenizer
    tokenizer = tokenizer_for_convert


def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                       doc_stride, max_query_length, is_training,
                                       return_dataset=False, threads=1):
def squad_convert_examples_to_features(
    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model.
    It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
...
...
@@ -283,13 +294,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
    features = []
    threads = min(threads, cpu_count())
    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
        annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length, doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
        features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
        annotate_ = partial(
            squad_convert_example_to_features,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=is_training,
        )
        features = list(
            tqdm(
                p.imap(annotate_, examples, chunksize=32),
                total=len(examples),
                desc="convert squad examples to features",
            )
        )
    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
    for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
        if not example_features:
            continue
        for example_feature in example_features:
...
...
@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            example_index += 1
    features = new_features
    del new_features
    if return_dataset == 'pt':
    if return_dataset == "pt":
        if not is_torch_available():
            raise ImportError("Pytorch must be installed to return a pytorch dataset.")
...
...
@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    }, {
                    },
                    {
                        "start_position": ex.start_position,
                        "end_position": ex.end_position,
                        "cls_index": ex.cls_index,
                        "p_mask": ex.p_mask,
                    }
                    },
                )

        return tf.data.Dataset.from_generator(
...
...
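
A hedged sketch of the multiprocessing pipeline above; the data directory is a placeholder, and with the default return_dataset=False the helper simply returns the feature list:

from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = SquadV2Processor().get_dev_examples("/tmp/squad_v2")  # placeholder path
features = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    threads=4,  # fanned out over the Pool shown above
)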
transformers/data/processors/utils.py View file @ 54abc67a
...
...
@@ -14,16 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys
import copy
import csv
import json
import logging
import sys

from ...file_utils import is_tf_available, is_torch_available

logger = logging.getLogger(__name__)


class InputExample(object):
    """
    A single training/test example for simple sequence classification.
...
...
@@ -37,6 +39,7 @@ class InputExample(object):
        label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
...
...
@@ -99,14 +102,15 @@ class DataProcessor(object):
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                    line = list(unicode(cell, "utf-8") for cell in line)  # noqa: F821
                lines.append(line)
            return lines


class SingleSentenceClassificationProcessor(DataProcessor):
    """ Generic processor for a single sentence classification data set."""

    def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
    def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples
        self.mode = mode
...
...
@@ -117,22 +121,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return SingleSentenceClassificationProcessor(labels=self.labels,
                                                         examples=self.examples[idx])
            return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
        return self.examples[idx]

    @classmethod
    def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
                        column_id=None, skip_first_row=False, **kwargs):
    def create_from_csv(
        cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
    ):
        processor = cls(**kwargs)
        processor.add_examples_from_csv(file_name,
        processor.add_examples_from_csv(
            file_name,
            split_name=split_name,
            column_label=column_label,
            column_text=column_text,
            column_id=column_id,
            skip_first_row=skip_first_row,
            overwrite_labels=True,
            overwrite_examples=True)
            overwrite_examples=True,
        )
        return processor

    @classmethod
...
...
@@ -141,8 +147,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
        processor.add_examples(texts_or_text_and_labels, labels=labels)
        return processor

    def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1,
                              column_id=None, skip_first_row=False,
                              overwrite_labels=False, overwrite_examples=False):
    def add_examples_from_csv(
        self,
        file_name,
        split_name="",
        column_label=0,
        column_text=1,
        column_id=None,
        skip_first_row=False,
        overwrite_labels=False,
        overwrite_examples=False,
    ):
        lines = self._read_tsv(file_name)
        if skip_first_row:
            lines = lines[1:]
...
...
@@ -158,10 +173,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
            guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
            ids.append(guid)

        return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels,
                                 overwrite_examples=overwrite_examples)
        return self.add_examples(
            texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
        )

    def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
                     overwrite_labels=False, overwrite_examples=False):
    def add_examples(
        self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
    ):
        assert labels is None or len(texts_or_text_and_labels) == len(labels)
        assert ids is None or len(texts_or_text_and_labels) == len(ids)
        if ids is None:
...
...
@@ -192,13 +210,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
        return self.examples

    def get_features(self,
    def get_features(
        self,
        tokenizer,
        max_length=None,
        pad_on_left=False,
        pad_token=0,
        mask_padding_with_zero=True,
        return_tensors=None):
        return_tensors=None,
    ):
        """
        Convert examples in a list of ``InputFeatures``
...
...
@@ -231,9 +251,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
                logger.info("Tokenizing example %d", ex_index)

            input_ids = tokenizer.encode(
                example.text_a,
                add_special_tokens=True,
                max_length=min(max_length, tokenizer.max_len),
                example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
            )
            all_input_ids.append(input_ids)
...
...
@@ -256,8 +274,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)

            assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
            assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
            assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
                len(input_ids), batch_length
            )
            assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
                len(attention_mask), batch_length
            )

            if self.mode == "classification":
                label = label_map[example.label]
...
...
@@ -273,36 +295,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
                logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
                logger.info("label: %s (id = %d)" % (example.label, label))

            features.append(
                InputFeatures(input_ids=input_ids,
                              attention_mask=attention_mask,
                              label=label))
            features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))

        if return_tensors is None:
            return features
        elif return_tensors == 'tf':
        elif return_tensors == "tf":
            if not is_tf_available():
                raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
            import tensorflow as tf

            def gen():
                for ex in features:
                    yield ({'input_ids': ex.input_ids,
                            'attention_mask': ex.attention_mask},
                           ex.label)

            dataset = tf.data.Dataset.from_generator(gen,
                ({'input_ids': tf.int32, 'attention_mask': tf.int32}, tf.int64),
                ({'input_ids': tf.TensorShape([None]), 'attention_mask': tf.TensorShape([None])}, tf.TensorShape([])))
                    yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)

            dataset = tf.data.Dataset.from_generator(
                gen,
                ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
                ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
            )
            return dataset
        elif return_tensors == 'pt':
        elif return_tensors == "pt":
            if not is_torch_available():
                raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
            import torch
            from torch.utils.data import TensorDataset

            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)

            if self.mode == "classification":
...
...
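
A sketch of the single-sentence processor above, using only methods whose signatures appear in this diff (the texts and labels are made up):

from transformers import BertTokenizer
from transformers.data.processors.utils import SingleSentenceClassificationProcessor

processor = SingleSentenceClassificationProcessor()
processor.add_examples(["great movie", "terrible movie"], labels=["pos", "neg"])
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
features = processor.get_features(tokenizer, max_length=32)  # padded InputFeatures list
print(features[0].input_ids)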
transformers/data/processors/xnli.py View file @ 54abc67a
...
...
@@ -22,13 +22,15 @@ import os
from .utils import DataProcessor, InputExample

logger = logging.getLogger(__name__)


class XnliProcessor(DataProcessor):
    """Processor for the XNLI dataset.
    Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""

    def __init__(self, language, train_language=None):
    def __init__(self, language, train_language=None):
        self.language = language
        self.train_language = train_language
...
...
@@ -40,13 +42,12 @@ class XnliProcessor(DataProcessor):
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % ('train', i)
            guid = "%s-%s" % ("train", i)
            text_a = line[0]
            text_b = line[1]
            label = "contradiction" if line[2] == "contradictory" else line[2]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
...
...
@@ -59,19 +60,19 @@ class XnliProcessor(DataProcessor):
            language = line[0]
            if language != self.language:
                continue
            guid = "%s-%s" % ('test', i)
            guid = "%s-%s" % ("test", i)
            text_a = line[6]
            text_b = line[7]
            label = line[1]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]


xnli_processors = {
    "xnli": XnliProcessor,
}
...
...
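
A short sketch of the XNLI processor above (the data directory is a placeholder; get_labels needs no data on disk):

from transformers.data.processors.xnli import XnliProcessor

processor = XnliProcessor(language="en", train_language="en")
print(processor.get_labels())  # ['contradiction', 'entailment', 'neutral']
# examples = processor.get_test_examples("/tmp/XNLI")  # requires the XNLI TSV files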
transformers/file_utils.py  View file @ 54abc67a
...
@@ -3,35 +3,37 @@ Utilities for working with the local dataset cache.
 This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
 Copyright by the AllenNLP authors.
 """
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals

-import sys
+import fnmatch
 import json
 import logging
 import os
-import six
+import sys
 import tempfile
-import fnmatch
+from contextlib import contextmanager
 from functools import partial, wraps
 from hashlib import sha256
 from io import open

 import boto3
+import requests
+import six
 from botocore.config import Config
 from botocore.exceptions import ClientError
-import requests
+from filelock import FileLock
 from tqdm.auto import tqdm
-from contextlib import contextmanager

 from . import __version__

-from filelock import FileLock

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 try:
-    os.environ.setdefault('USE_TORCH', 'YES')
-    if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
+    os.environ.setdefault("USE_TORCH", "YES")
+    if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
         import torch

         _torch_available = True  # pylint: disable=invalid-name
         logger.info("PyTorch version {} available.".format(torch.__version__))
     else:
...
@@ -41,10 +43,11 @@ except ImportError:
     _torch_available = False  # pylint: disable=invalid-name

 try:
-    os.environ.setdefault('USE_TF', 'YES')
-    if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
+    os.environ.setdefault("USE_TF", "YES")
+    if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
         import tensorflow as tf

-        assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
+        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
         _tf_available = True  # pylint: disable=invalid-name
         logger.info("TensorFlow version {} available.".format(tf.__version__))
     else:
...
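The two try-blocks above implement the same optional-dependency idiom: attempt the import once at module load, record the result in a module-level flag, and expose it through a cheap accessor. A generic sketch of that pattern, using NumPy as a stand-in dependency (names are illustrative, not the library's):

import logging
import os

logger = logging.getLogger(__name__)

try:
    os.environ.setdefault("USE_NUMPY", "YES")
    if os.environ["USE_NUMPY"].upper() in ("1", "ON", "YES"):
        import numpy  # noqa: F401

        _numpy_available = True
        logger.info("NumPy version {} available.".format(numpy.__version__))
    else:
        _numpy_available = False
except ImportError:
    _numpy_available = False

def is_numpy_available():
    return _numpy_available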
@@ -55,12 +58,13 @@ except (ImportError, AssertionError):
 try:
     from torch.hub import _get_torch_home

     torch_cache_home = _get_torch_home()
 except ImportError:
     torch_cache_home = os.path.expanduser(
-        os.getenv('TORCH_HOME', os.path.join(
-            os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
-default_cache_path = os.path.join(torch_cache_home, 'transformers')
+        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
+    )
+default_cache_path = os.path.join(torch_cache_home, "transformers")

 try:
     from urllib.parse import urlparse
...
@@ -69,19 +73,21 @@ except ImportError:
 try:
     from pathlib import Path

     PYTORCH_PRETRAINED_BERT_CACHE = Path(
-        os.getenv('PYTORCH_TRANSFORMERS_CACHE',
-                  os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
+        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
+    )
 except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
-                                              os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                                        default_cache_path))
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
+        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
+    )

 PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
 TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

 WEIGHTS_NAME = "pytorch_model.bin"
-TF2_WEIGHTS_NAME = 'tf_model.h5'
-TF_WEIGHTS_NAME = 'model.ckpt'
+TF2_WEIGHTS_NAME = "tf_model.h5"
+TF_WEIGHTS_NAME = "model.ckpt"
 CONFIG_NAME = "config.json"
 MODEL_CARD_NAME = "modelcard.json"
...
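The nested os.getenv calls above resolve the cache directory from the newest environment variable, then the legacy one, then a default under the torch cache. A minimal sketch of the same fallback chain:

import os

default_cache_path = os.path.join(os.path.expanduser("~/.cache/torch"), "transformers")

# The first environment variable that is set wins; otherwise use the default.
cache_dir = os.getenv(
    "PYTORCH_TRANSFORMERS_CACHE",
    os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path),
)
print(cache_dir)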
@@ -95,38 +101,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
 def is_torch_available():
     return _torch_available

 def is_tf_available():
     return _tf_available

 if not six.PY2:

     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
-            fn.__doc__ = ''.join(docstr) + fn.__doc__
+            fn.__doc__ = "".join(docstr) + fn.__doc__
             return fn

         return docstring_decorator

     def add_end_docstrings(*docstr):
         def docstring_decorator(fn):
-            fn.__doc__ = fn.__doc__ + ''.join(docstr)
+            fn.__doc__ = fn.__doc__ + "".join(docstr)
             return fn

         return docstring_decorator

 else:
     # Not possible to update class docstrings on python2
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
             return fn

         return docstring_decorator

     def add_end_docstrings(*docstr):
         def docstring_decorator(fn):
             return fn

         return docstring_decorator

 def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
-    return parsed.scheme in ('http', 'https', 's3')
+    return parsed.scheme in ("http", "https", "s3")

 def hf_bucket_url(identifier, postfix=None, cdn=False):
     endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
...
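add_start_docstrings builds a decorator that prepends shared text to a function's __doc__ (a no-op on Python 2, where docstrings cannot be rewritten this way). A self-contained sketch of how such a decorator is used:

# Reimplementation for illustration, not the library's exact code path.
def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn

    return docstring_decorator

SHARED_INTRO = "Inputs are token ids produced by the matching tokenizer.\n\n"

@add_start_docstrings(SHARED_INTRO)
def forward(input_ids):
    """Runs the model on `input_ids`."""

print(forward.__doc__)  # shared intro followed by the function's own docstring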
@@ -145,17 +161,17 @@ def url_to_filename(url, etag=None):
     so that TF 2.0 can identify it as a HDF5 file
     (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
-    url_bytes = url.encode('utf-8')
+    url_bytes = url.encode("utf-8")
     url_hash = sha256(url_bytes)
     filename = url_hash.hexdigest()

     if etag:
-        etag_bytes = etag.encode('utf-8')
+        etag_bytes = etag.encode("utf-8")
         etag_hash = sha256(etag_bytes)
-        filename += '.' + etag_hash.hexdigest()
+        filename += "." + etag_hash.hexdigest()

-    if url.endswith('.h5'):
-        filename += '.h5'
+    if url.endswith(".h5"):
+        filename += ".h5"

     return filename
...
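The hunk above hashes a URL (and optionally its ETag) into a stable cache filename, so a changed ETag, meaning new remote content, produces a new cache entry. A runnable sketch of the same scheme:

from hashlib import sha256

def url_to_filename(url, etag=None):
    filename = sha256(url.encode("utf-8")).hexdigest()
    if etag:
        filename += "." + sha256(etag.encode("utf-8")).hexdigest()
    if url.endswith(".h5"):
        filename += ".h5"  # keep the extension so TF 2.0 recognizes HDF5 files
    return filename

print(url_to_filename("https://example.com/model.bin", etag='"abc123"'))  # placeholder URL/ETag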
@@ -174,19 +190,21 @@ def filename_to_url(filename, cache_dir=None):
     if not os.path.exists(cache_path):
         raise EnvironmentError("file {} not found".format(cache_path))

-    meta_path = cache_path + '.json'
+    meta_path = cache_path + ".json"
     if not os.path.exists(meta_path):
         raise EnvironmentError("file {} not found".format(meta_path))

     with open(meta_path, encoding="utf-8") as meta_file:
         metadata = json.load(meta_file)
-    url = metadata['url']
-    etag = metadata['etag']
+    url = metadata["url"]
+    etag = metadata["etag"]

     return url, etag

-def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None,
-                resume_download=False, user_agent=None):
+def cached_path(
+    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
+):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
...
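Usage sketch for cached_path, the entry point whose signature is reformatted above: a URL is downloaded into the cache (keyed by URL and ETag) and the cached path returned, while an existing local path is returned unchanged. The remote URL below is a placeholder.

from transformers.file_utils import cached_path

# An existing local path is returned unchanged (no download involved).
print(cached_path(__file__))

# A URL is downloaded once and the local cached path returned on later calls.
# Placeholder URL, so left commented:
# local_file = cached_path(
#     "https://example.com/some-model/config.json",
#     force_download=False,   # reuse the cache while the ETag still matches
#     resume_download=True,   # continue from a partial .incomplete file
# )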
@@ -207,13 +225,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
     if is_remote_url(url_or_filename):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download,
-                              proxies=proxies, resume_download=resume_download, user_agent=user_agent)
+        return get_from_cache(
+            url_or_filename,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            user_agent=user_agent,
+        )
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
-    elif urlparse(url_or_filename).scheme == '':
+    elif urlparse(url_or_filename).scheme == "":
         # File, but it doesn't exist.
         raise EnvironmentError("file {} not found".format(url_or_filename))
     else:
...
@@ -273,23 +296,25 @@ def s3_get(url, temp_file, proxies=None):
 def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if isinstance(user_agent, dict):
-        ua += "; " + "; ".join(
-            "{}/{}".format(k, v) for k, v in user_agent.items()
-        )
+        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, six.string_types):
         ua += "; " + user_agent
     headers = {"user-agent": ua}
     if resume_size > 0:
-        headers['Range'] = 'bytes=%d-' % (resume_size,)
+        headers["Range"] = "bytes=%d-" % (resume_size,)
     response = requests.get(url, stream=True, proxies=proxies, headers=headers)
     if response.status_code == 416:  # Range not satisfiable
         return
-    content_length = response.headers.get('Content-Length')
+    content_length = response.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
-    progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
-                    desc="Downloading", disable=bool(logger.level <= logging.INFO))
+    progress = tqdm(
+        unit="B",
+        unit_scale=True,
+        total=total,
+        initial=resume_size,
+        desc="Downloading",
+        disable=bool(logger.level <= logging.INFO),
+    )
     for chunk in response.iter_content(chunk_size=1024):
         if chunk:  # filter out keep-alive new chunks
             progress.update(len(chunk))
...
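The Range/416 logic above is the standard recipe for resumable HTTP downloads: send a Range header with the number of bytes already on disk, append to the partial file, and treat HTTP 416 as "nothing left to fetch". A compact sketch, assuming a server that honors Range requests (the URL is a placeholder):

import os

import requests

def http_resume_get(url, path):
    resume_size = os.stat(path).st_size if os.path.exists(path) else 0
    headers = {"Range": "bytes=%d-" % resume_size} if resume_size > 0 else {}
    response = requests.get(url, stream=True, headers=headers)
    if response.status_code == 416:  # Range not satisfiable: file already complete
        return
    response.raise_for_status()
    with open(path, "ab") as f:  # append, keeping the bytes we already have
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)

# http_resume_get("https://example.com/big-file.bin", "/tmp/big-file.bin")  # placeholder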
@@ -297,7 +322,9 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
     progress.close()

-def get_from_cache(url, cache_dir=None, force_download=False, proxies=None,
-                   etag_timeout=10, resume_download=False, user_agent=None):
+def get_from_cache(
+    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
+):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
...
@@ -326,7 +353,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
             etag = None

     if sys.version_info[0] == 2 and etag is not None:
-        etag = etag.decode('utf-8')
+        etag = etag.decode("utf-8")

     filename = url_to_filename(url, etag)

     # get cache path to put the file
...
@@ -337,22 +364,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
     if not os.path.exists(cache_path) and etag is None:
         matching_files = [
             file
-            for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
-            if not file.endswith('.json') and not file.endswith('.lock')
+            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
+            if not file.endswith(".json") and not file.endswith(".lock")
         ]
         if matching_files:
             cache_path = os.path.join(cache_dir, matching_files[-1])

     # Prevent parallel downloads of the same file with a lock.
-    lock_path = cache_path + '.lock'
+    lock_path = cache_path + ".lock"
     with FileLock(lock_path):

         if resume_download:
-            incomplete_path = cache_path + '.incomplete'
+            incomplete_path = cache_path + ".incomplete"

             @contextmanager
             def _resumable_file_manager():
-                with open(incomplete_path, 'a+b') as f:
+                with open(incomplete_path, "a+b") as f:
                     yield f

             temp_file_manager = _resumable_file_manager
             if os.path.exists(incomplete_path):
                 resume_size = os.stat(incomplete_path).st_size
...
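FileLock serializes concurrent downloads of one cache entry across processes by locking a sibling ".lock" file. A minimal sketch of that idiom (requires the filelock package; paths are placeholders):

import os

from filelock import FileLock

os.makedirs("/tmp/demo-cache", exist_ok=True)
cache_path = "/tmp/demo-cache/entry.bin"
lock_path = cache_path + ".lock"

with FileLock(lock_path):
    # Only one process at a time runs this block for this cache entry, so
    # readers of cache_path never observe a half-written file.
    with open(cache_path, "wb") as f:
        f.write(b"payload")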
@@ -366,7 +395,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
         with temp_file_manager() as temp_file:
-            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
+            logger.info(
+                "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
+            )

             # GET file object
             if url.startswith("s3://"):
...
@@ -383,12 +414,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
             os.rename(temp_file.name, cache_path)

             logger.info("creating metadata file for %s", cache_path)
-            meta = {'url': url, 'etag': etag}
-            meta_path = cache_path + '.json'
-            with open(meta_path, 'w') as meta_file:
+            meta = {"url": url, "etag": etag}
+            meta_path = cache_path + ".json"
+            with open(meta_path, "w") as meta_file:
                 output_string = json.dumps(meta)
                 if sys.version_info[0] == 2 and isinstance(output_string, str):
-                    output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
+                    output_string = unicode(output_string, "utf-8")  # noqa: F821
                 meta_file.write(output_string)

     return cache_path
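Each cached file gets a ".json" sidecar recording its URL and ETag, which is exactly what filename_to_url reads back to invert the hashed cache name. A small sketch of writing and reading such a sidecar (paths and values are placeholders):

import json
import os

os.makedirs("/tmp/demo-cache", exist_ok=True)
cache_path = "/tmp/demo-cache/0a1b2c.bin"  # hypothetical cached file
meta = {"url": "https://example.com/model.bin", "etag": '"abc123"'}

with open(cache_path + ".json", "w") as meta_file:
    json.dump(meta, meta_file)

with open(cache_path + ".json") as meta_file:
    metadata = json.load(meta_file)
print(metadata["url"], metadata["etag"])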
transformers/hf_api.py  View file @ 54abc67a
...
@@ -14,16 +14,19 @@
 # limitations under the License.

 from __future__ import absolute_import, division, print_function

 import io
 import os
 from os.path import expanduser
 from typing import List

 import requests
 import six
 from requests.exceptions import HTTPError
 from tqdm import tqdm

 ENDPOINT = "https://huggingface.co"

 class S3Obj:
     def __init__(
         self,
...
@@ -78,8 +81,7 @@ class HfApi:
         return d["token"]

     def whoami(
-        self,
-        token,  # type: str
+        self, token,  # type: str
     ):
         # type: (...) -> str
         """
...
@@ -92,7 +94,7 @@ class HfApi:
         return d["user"]

     def logout(self, token):
-        # type: (...) -> void
+        # type: (...) -> None
         """
         Call HF API to log out.
         """
...
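This part of the diff migrates Python 2 style "type comments" toward real Python 3 annotations and fixes the invalid `-> void` comment (`void` is not a Python type; `None` is). A side-by-side sketch of the two styles on a hypothetical class:

from typing import List

class ApiSketch:  # illustrative class, not the library's
    def whoami_py2(self, token):
        # type: (str) -> str
        return "user"

    def list_objs(self, token: str) -> List[str]:  # Python 3 annotations
        return []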
@@ -106,11 +108,7 @@ class HfApi:
         Call HF API to get a presigned url to upload `filename` to S3.
         """
         path = "{}/api/presign".format(self.endpoint)
-        r = requests.post(
-            path,
-            headers={"authorization": "Bearer {}".format(token)},
-            json={"filename": filename},
-        )
+        r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
         r.raise_for_status()
         d = r.json()
         return PresignedUrl(**d)
...
@@ -133,15 +131,12 @@ class HfApi:
         pf = TqdmProgressFileReader(f)
         data = f if pf.total_size > 0 else ""

-        r = requests.put(urls.write, data=data, headers={
-            "content-type": urls.type,
-        })
+        r = requests.put(urls.write, data=data, headers={"content-type": urls.type})
         r.raise_for_status()
         pf.close()
         return urls.access

-    def list_objs(self, token):
-        # type: (...) -> List[S3Obj]
+    def list_objs(self, token) -> List[S3Obj]:
         """
         Call HF API to list all stored files for user.
         """
...
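The upload path above is a classic presigned-URL flow: exchange a filename for a one-time S3 write URL, then PUT the bytes directly to S3 so they never pass through the API server. A sketch of the same flow with plain requests; the token and file are hypothetical, and the response keys (write, type, access) mirror the fields used above:

import requests

ENDPOINT = "https://huggingface.co"
token, filename = "hf_xxx", "weights.bin"  # hypothetical credentials and file

# 1. Exchange the filename for a presigned write URL.
r = requests.post(
    "{}/api/presign".format(ENDPOINT),
    headers={"authorization": "Bearer {}".format(token)},
    json={"filename": filename},
)
r.raise_for_status()
urls = r.json()  # assumed keys: "write", "type", "access"

# 2. Upload the raw bytes directly to the presigned S3 URL.
with open(filename, "rb") as f:
    requests.put(urls["write"], data=f, headers={"content-type": urls["type"]}).raise_for_status()
print("readable at:", urls["access"])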
@@ -152,7 +147,6 @@ class HfApi:
         return [S3Obj(**x) for x in d]

 class TqdmProgressFileReader:
     """
     Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
...
@@ -161,10 +155,8 @@ class TqdmProgressFileReader:
     see github.com/huggingface/transformers/pull/2078#discussion_r354739608
     for implementation details.
     """

-    def __init__(self, f  # type: io.BufferedReader
-    ):
+    def __init__(self, f: io.BufferedReader):
         self.f = f
         self.total_size = os.fstat(f.fileno()).st_size  # type: int
         self.pbar = tqdm(total=self.total_size, leave=False)
...
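TqdmProgressFileReader ties a tqdm bar to read() calls so uploads report progress as the HTTP client consumes the file. A simplified standalone variant that wraps rather than patches the file object (the library instead replaces f.read in place; see the PR discussion linked above):

import io
import os

from tqdm import tqdm

class ProgressReader:
    """File-like wrapper: delegates read() and advances a tqdm bar per chunk."""

    def __init__(self, f: io.BufferedReader):
        self.f = f
        self.total_size = os.fstat(f.fileno()).st_size
        self.pbar = tqdm(total=self.total_size, unit="B", unit_scale=True, leave=False)

    def read(self, n=-1):
        chunk = self.f.read(n)
        self.pbar.update(len(chunk))
        return chunk

    def close(self):
        self.pbar.close()

# Usage: read this very script in chunks, showing progress.
with open(__file__, "rb") as f:
    reader = ProgressReader(f)
    while reader.read(1024):
        pass
    reader.close()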
@@ -182,7 +174,6 @@ class TqdmProgressFileReader:
         self.pbar.close()

 class HfFolder:
     path_token = expanduser("~/.huggingface/token")
...
@@ -201,7 +192,7 @@ class HfFolder:
             if e.errno != os.errno.EEXIST:
                 raise e
             pass
-        with open(cls.path_token, 'w+') as f:
+        with open(cls.path_token, "w+") as f:
             f.write(token)

     @classmethod
...
@@ -210,12 +201,10 @@ class HfFolder:
         Get token or None if not existent.
         """
         try:
-            with open(cls.path_token, 'r') as f:
+            with open(cls.path_token, "r") as f:
                 return f.read()
-        except:
-            # this is too wide. When Py2 is dead use:
-            # `except FileNotFoundError:` instead
-            return None
+        except FileNotFoundError:
+            pass

     @classmethod
     def delete_token(cls):
...
@@ -225,5 +214,5 @@ class HfFolder:
         """
         try:
             os.remove(cls.path_token)
-        except:
-            return
+        except FileNotFoundError:
+            pass
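The last two hunks replace bare `except:` clauses, which also swallow KeyboardInterrupt, SystemExit, and genuine bugs, with the single expected exception. A tiny sketch of the narrowed pattern (the path is a placeholder):

import os

def delete_token(path="~/.huggingface/token"):
    try:
        os.remove(os.path.expanduser(path))
    except FileNotFoundError:
        pass  # already gone; nothing to do

delete_token()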