Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8c276b9c
Unverified
Commit
8c276b9c
authored
Nov 27, 2019
by
Stefan Schweter
Committed by
GitHub
Nov 27, 2019
Browse files
Merge branch 'master' into distilbert-german
parents
da06afaf
3c28a2da
Changes
54
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1978 additions
and
29 deletions
+1978
-29
transformers/configuration_albert.py
transformers/configuration_albert.py
+100
-0
transformers/configuration_auto.py
transformers/configuration_auto.py
+3
-0
transformers/configuration_distilbert.py
transformers/configuration_distilbert.py
+1
-0
transformers/configuration_utils.py
transformers/configuration_utils.py
+6
-1
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py
...rmers/convert_albert_original_tf_checkpoint_to_pytorch.py
+67
-0
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+9
-4
transformers/data/__init__.py
transformers/data/__init__.py
+2
-1
transformers/data/metrics/__init__.py
transformers/data/metrics/__init__.py
+8
-0
transformers/data/processors/__init__.py
transformers/data/processors/__init__.py
+1
-1
transformers/data/processors/xnli.py
transformers/data/processors/xnli.py
+85
-0
transformers/file_utils.py
transformers/file_utils.py
+36
-11
transformers/modeling_albert.py
transformers/modeling_albert.py
+764
-0
transformers/modeling_auto.py
transformers/modeling_auto.py
+21
-0
transformers/modeling_bert.py
transformers/modeling_bert.py
+8
-8
transformers/modeling_camembert.py
transformers/modeling_camembert.py
+37
-1
transformers/modeling_distilbert.py
transformers/modeling_distilbert.py
+1
-0
transformers/modeling_tf_albert.py
transformers/modeling_tf_albert.py
+799
-0
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+12
-0
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+7
-1
transformers/modeling_utils.py
transformers/modeling_utils.py
+11
-1
No files found.
transformers/configuration_albert.py
0 → 100644
View file @
8c276b9c
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ALBERT model configuration """
from
.configuration_utils
import
PretrainedConfig
# Map from pretrained ALBERT shortcut name to the URL of its configuration file.
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
    'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
    'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
    'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
    'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
    'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
    'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}
class AlbertConfig(PretrainedConfig):
    """Configuration for `AlbertModel`.

    The default settings match the configuration of model `albert_xxlarge`.
    """

    pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30000,
                 embedding_size=128,
                 hidden_size=4096,
                 num_hidden_layers=12,
                 num_hidden_groups=1,
                 num_attention_heads=64,
                 intermediate_size=16384,
                 inner_group_num=1,
                 hidden_act="gelu_new",
                 hidden_dropout_prob=0,
                 attention_probs_dropout_prob=0,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        """Constructs AlbertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in
                `AlbertModel`. NOTE(review): despite the name, this value is
                stored directly as `self.vocab_size`; no JSON file is parsed
                here — confirm callers never pass a path.
            embedding_size: Size of the vocabulary embeddings.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_hidden_groups: Number of groups for the hidden layers; parameters
                in the same group are shared.
            num_attention_heads: Number of attention heads for each attention
                layer in the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            inner_group_num: int, number of inner repetitions of attention and ffn.
            hidden_act: The non-linear activation function (function or string)
                in the encoder and pooler.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model
                might ever be used with. Typically set this to something large
                just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed
                into `AlbertModel`.
            initializer_range: The stdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by the layer normalization layers.
            **kwargs: Forwarded to `PretrainedConfig.__init__`.
        """
        super(AlbertConfig, self).__init__(**kwargs)

        self.vocab_size = vocab_size_or_config_json_file
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_groups = num_hidden_groups
        self.num_attention_heads = num_attention_heads
        self.inner_group_num = inner_group_num
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
transformers/configuration_auto.py
View file @
8c276b9c
...
@@ -95,6 +95,9 @@ class AutoConfig(object):
...
@@ -95,6 +95,9 @@ class AutoConfig(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/configuration_distilbert.py
View file @
8c276b9c
...
@@ -29,6 +29,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
...
@@ -29,6 +29,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json"
,
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json"
,
'distilbert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json"
,
'distilbert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json"
,
'distilbert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json"
,
}
}
...
...
transformers/configuration_utils.py
View file @
8c276b9c
...
@@ -94,6 +94,9 @@ class PretrainedConfig(object):
...
@@ -94,6 +94,9 @@ class PretrainedConfig(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -120,6 +123,7 @@ class PretrainedConfig(object):
...
@@ -120,6 +123,7 @@ class PretrainedConfig(object):
"""
"""
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
return_unused_kwargs
=
kwargs
.
pop
(
'return_unused_kwargs'
,
False
)
return_unused_kwargs
=
kwargs
.
pop
(
'return_unused_kwargs'
,
False
)
...
@@ -131,7 +135,8 @@ class PretrainedConfig(object):
...
@@ -131,7 +135,8 @@ class PretrainedConfig(object):
config_file
=
pretrained_model_name_or_path
config_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
...
...
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py
0 → 100644
View file @
8c276b9c
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
torch
from
transformers
import
AlbertConfig
,
AlbertForMaskedLM
,
load_tf_weights_in_albert
import
logging
logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    """Build a PyTorch ALBERT model from a JSON config, fill it with weights
    from a TensorFlow checkpoint, and save its state dict to disk."""
    # Build an empty PyTorch model from the JSON configuration.
    albert_config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(albert_config)))
    pt_model = AlbertForMaskedLM(albert_config)

    # Copy the TensorFlow checkpoint weights into the PyTorch model in place.
    load_tf_weights_in_albert(pt_model, albert_config, tf_checkpoint_path)

    # Persist only the state dict (not the full module) to the requested path.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(pt_model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tf_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
             "This specifies the model architecture.")
    parser.add_argument(
        "--pytorch_dump_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model.")
    cli_args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(cli_args.tf_checkpoint_path,
                                     cli_args.albert_config_file,
                                     cli_args.pytorch_dump_path)
\ No newline at end of file
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
8c276b9c
...
@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
...
@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AlbertConfig
,
TFAlbertForMaskedLM
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
...
@@ -46,7 +47,8 @@ if is_torch_available():
...
@@ -46,7 +47,8 @@ if is_torch_available():
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
else
:
(
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
(
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
...
@@ -56,7 +58,8 @@ else:
...
@@ -56,7 +58,8 @@ else:
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
...
@@ -65,6 +68,7 @@ else:
...
@@ -65,6 +68,7 @@ else:
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
)
None
,
None
)
...
@@ -85,7 +89,8 @@ MODEL_CLASSES = {
...
@@ -85,7 +89,8 @@ MODEL_CLASSES = {
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'albert'
:
(
AlbertConfig
,
TFAlbertForMaskedLM
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
}
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
...
...
transformers/data/__init__.py
View file @
8c276b9c
from
.processors
import
InputExample
,
InputFeatures
,
DataProcessor
from
.processors
import
InputExample
,
InputFeatures
,
DataProcessor
from
.processors
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
from
.processors
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
from
.processors
import
xnli_output_modes
,
xnli_processors
,
xnli_tasks_num_labels
from
.metrics
import
is_sklearn_available
from
.metrics
import
is_sklearn_available
if
is_sklearn_available
():
if
is_sklearn_available
():
from
.metrics
import
glue_compute_metrics
from
.metrics
import
glue_compute_metrics
,
xnli_compute_metrics
transformers/data/metrics/__init__.py
View file @
8c276b9c
...
@@ -81,3 +81,11 @@ if _has_sklearn:
...
@@ -81,3 +81,11 @@ if _has_sklearn:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
else
:
raise
KeyError
(
task_name
)
raise
KeyError
(
task_name
)
def xnli_compute_metrics(task_name, preds, labels):
    """Compute evaluation metrics for an XNLI task.

    Returns a dict with simple accuracy for task "xnli"; raises KeyError
    for any other task name. `preds` and `labels` must have equal length.
    """
    assert len(preds) == len(labels)
    # Guard clause: "xnli" is the only task this module knows about.
    if task_name != "xnli":
        raise KeyError(task_name)
    return {"acc": simple_accuracy(preds, labels)}
transformers/data/processors/__init__.py
View file @
8c276b9c
from
.utils
import
InputExample
,
InputFeatures
,
DataProcessor
from
.utils
import
InputExample
,
InputFeatures
,
DataProcessor
from
.glue
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
from
.glue
import
glue_output_modes
,
glue_processors
,
glue_tasks_num_labels
,
glue_convert_examples_to_features
from
.xnli
import
xnli_output_modes
,
xnli_processors
,
xnli_tasks_num_labels
transformers/data/processors/xnli.py
0 → 100644
View file @
8c276b9c
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XNLI utils (dataset loading and evaluation) """
from
__future__
import
absolute_import
,
division
,
print_function
import
logging
import
os
from
.utils
import
DataProcessor
,
InputExample
logger
=
logging
.
getLogger
(
__name__
)
class XnliProcessor(DataProcessor):
    """Processor for the XNLI dataset.
    Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""

    def __init__(self, language, train_language=None):
        # `train_language` optionally overrides `language` for the training
        # split only (the MT'd training files exist per language).
        self.language = language
        self.train_language = train_language

    def get_train_examples(self, data_dir):
        """See base class."""
        lg = self.train_language if self.train_language is not None else self.language
        train_path = os.path.join(data_dir,
                                  "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg))
        examples = []
        for row_idx, row in enumerate(self._read_tsv(train_path)):
            if row_idx == 0:
                continue  # skip the header row
            guid = "%s-%s" % ('train', row_idx)
            text_a = row[0]
            text_b = row[1]
            # The MT training files use "contradictory" where XNLI proper
            # uses "contradiction"; normalize to the canonical label.
            label = "contradiction" if row[2] == "contradictory" else row[2]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        test_path = os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv")
        examples = []
        for row_idx, row in enumerate(self._read_tsv(test_path)):
            if row_idx == 0:
                continue  # skip the header row
            if row[0] != self.language:
                continue  # keep only rows for the requested language
            guid = "%s-%s" % ('test', row_idx)
            text_a = row[6]
            text_b = row[7]
            label = row[1]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]
# Task registries: map each XNLI task name to its processor class,
# output mode, and number of labels.
xnli_processors = {"xnli": XnliProcessor}

xnli_output_modes = {"xnli": "classification"}

xnli_tasks_num_labels = {"xnli": 3}
transformers/file_utils.py
View file @
8c276b9c
...
@@ -22,6 +22,7 @@ from botocore.config import Config
...
@@ -22,6 +22,7 @@ from botocore.config import Config
from
botocore.exceptions
import
ClientError
from
botocore.exceptions
import
ClientError
import
requests
import
requests
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
contextlib
import
contextmanager
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
...
@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
return
url
,
etag
return
url
,
etag
def
cached_path
(
url_or_filename
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
):
def
cached_path
(
url_or_filename
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
resume_download
=
False
):
"""
"""
Given something that might be a URL (or might be a local path),
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
determine which. If it's a URL, download the file and cache it, and
...
@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
...
@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
Args:
Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if an incompletely received file is found.
"""
"""
if
cache_dir
is
None
:
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
cache_dir
=
TRANSFORMERS_CACHE
...
@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
...
@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if
parsed
.
scheme
in
(
'http'
,
'https'
,
's3'
):
if
parsed
.
scheme
in
(
'http'
,
'https'
,
's3'
):
# URL, so get it from the cache (downloading if necessary)
# URL, so get it from the cache (downloading if necessary)
return
get_from_cache
(
url_or_filename
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
return
get_from_cache
(
url_or_filename
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
elif
os
.
path
.
exists
(
url_or_filename
):
elif
os
.
path
.
exists
(
url_or_filename
):
# File, and it exists.
# File, and it exists.
return
url_or_filename
return
url_or_filename
...
@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
...
@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
s3_resource
.
Bucket
(
bucket_name
).
download_fileobj
(
s3_path
,
temp_file
)
s3_resource
.
Bucket
(
bucket_name
).
download_fileobj
(
s3_path
,
temp_file
)
def
http_get
(
url
,
temp_file
,
proxies
=
None
):
def
http_get
(
url
,
temp_file
,
proxies
=
None
,
resume_size
=
0
):
req
=
requests
.
get
(
url
,
stream
=
True
,
proxies
=
proxies
)
headers
=
{
'Range'
:
'bytes=%d-'
%
(
resume_size
,)}
if
resume_size
>
0
else
None
content_length
=
req
.
headers
.
get
(
'Content-Length'
)
response
=
requests
.
get
(
url
,
stream
=
True
,
proxies
=
proxies
,
headers
=
headers
)
total
=
int
(
content_length
)
if
content_length
is
not
None
else
None
if
response
.
status_code
==
416
:
# Range not satisfiable
progress
=
tqdm
(
unit
=
"B"
,
total
=
total
)
return
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
content_length
=
response
.
headers
.
get
(
'Content-Length'
)
total
=
resume_size
+
int
(
content_length
)
if
content_length
is
not
None
else
None
progress
=
tqdm
(
unit
=
"B"
,
total
=
total
,
initial
=
resume_size
)
for
chunk
in
response
.
iter_content
(
chunk_size
=
1024
):
if
chunk
:
# filter out keep-alive new chunks
if
chunk
:
# filter out keep-alive new chunks
progress
.
update
(
len
(
chunk
))
progress
.
update
(
len
(
chunk
))
temp_file
.
write
(
chunk
)
temp_file
.
write
(
chunk
)
progress
.
close
()
progress
.
close
()
def
get_from_cache
(
url
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
etag_timeout
=
10
):
def
get_from_cache
(
url
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
etag_timeout
=
10
,
resume_download
=
False
):
"""
"""
Given a URL, look for the corresponding dataset in the local cache.
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
If it's not there, download it. Then return the path to the cached file.
...
@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
...
@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if
matching_files
:
if
matching_files
:
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
if
resume_download
:
incomplete_path
=
cache_path
+
'.incomplete'
@
contextmanager
def
_resumable_file_manager
():
with
open
(
incomplete_path
,
'a+b'
)
as
f
:
yield
f
os
.
remove
(
incomplete_path
)
temp_file_manager
=
_resumable_file_manager
if
os
.
path
.
exists
(
incomplete_path
):
resume_size
=
os
.
stat
(
incomplete_path
).
st_size
else
:
resume_size
=
0
else
:
temp_file_manager
=
tempfile
.
NamedTemporaryFile
resume_size
=
0
if
not
os
.
path
.
exists
(
cache_path
)
or
force_download
:
if
not
os
.
path
.
exists
(
cache_path
)
or
force_download
:
# Download to temporary file, then copy to cache dir once finished.
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with
tempfile
.
NamedTemporaryFile
()
as
temp_file
:
with
temp
_
file
_manager
()
as
temp_file
:
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
# GET file object
# GET file object
if
url
.
startswith
(
"s3://"
):
if
url
.
startswith
(
"s3://"
):
if
resume_download
:
logger
.
warn
(
'Warning: resumable downloads are not implemented for "s3://" urls'
)
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
else
:
else
:
http_get
(
url
,
temp_file
,
proxies
=
proxies
)
http_get
(
url
,
temp_file
,
proxies
=
proxies
,
resume_size
=
resume_size
)
# we are copying the file before closing it, so flush to avoid truncation
# we are copying the file before closing it, so flush to avoid truncation
temp_file
.
flush
()
temp_file
.
flush
()
...
...
transformers/modeling_albert.py
0 → 100644
View file @
8c276b9c
This diff is collapsed.
Click to expand it.
transformers/modeling_auto.py
View file @
8c276b9c
...
@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi
...
@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi
from
.modeling_xlm
import
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
from
.modeling_xlm
import
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
...
@@ -48,6 +49,7 @@ class AutoModel(object):
...
@@ -48,6 +49,7 @@ class AutoModel(object):
The base model class to instantiate is selected as the first pattern matching
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
...
@@ -71,6 +73,7 @@ class AutoModel(object):
...
@@ -71,6 +73,7 @@ class AutoModel(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
...
@@ -112,6 +115,9 @@ class AutoModel(object):
...
@@ -112,6 +115,9 @@ class AutoModel(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -138,6 +144,8 @@ class AutoModel(object):
...
@@ -138,6 +144,8 @@ class AutoModel(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
RobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
...
@@ -172,6 +180,7 @@ class AutoModelWithLMHead(object):
...
@@ -172,6 +180,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
...
@@ -198,6 +207,7 @@ class AutoModelWithLMHead(object):
...
@@ -198,6 +207,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
...
@@ -237,6 +247,8 @@ class AutoModelWithLMHead(object):
...
@@ -237,6 +247,8 @@ class AutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
...
@@ -264,6 +276,8 @@ class AutoModelWithLMHead(object):
...
@@ -264,6 +276,8 @@ class AutoModelWithLMHead(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
RobertaForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
...
@@ -298,6 +312,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -298,6 +312,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
...
@@ -320,6 +335,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -320,6 +335,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
...
@@ -357,6 +373,9 @@ class AutoModelForSequenceClassification(object):
...
@@ -357,6 +373,9 @@ class AutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -383,6 +402,8 @@ class AutoModelForSequenceClassification(object):
...
@@ -383,6 +402,8 @@ class AutoModelForSequenceClassification(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
RobertaForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
...
...
transformers/modeling_bert.py
View file @
8c276b9c
...
@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
...
@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
if
len
(
heads
)
==
0
:
if
len
(
heads
)
==
0
:
return
return
mask
=
torch
.
ones
(
self
.
self
.
num_attention_heads
,
self
.
self
.
attention_head_size
)
mask
=
torch
.
ones
(
self
.
self
.
num_attention_heads
,
self
.
self
.
attention_head_size
)
heads
=
set
(
heads
)
-
self
.
pruned_heads
# Convert to set and remove already pruned heads
heads
=
set
(
heads
)
-
self
.
pruned_heads
# Convert to set and remove already pruned heads
for
head
in
heads
:
for
head
in
heads
:
# Compute how many pruned heads are before the head and move the index accordingly
# Compute how many pruned heads are before the head and move the index accordingly
head
=
head
-
sum
(
1
if
h
<
head
else
0
for
h
in
self
.
pruned_heads
)
head
=
head
-
sum
(
1
if
h
<
head
else
0
for
h
in
self
.
pruned_heads
)
...
@@ -597,7 +597,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -597,7 +597,7 @@ class BertModel(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
...
@@ -760,7 +760,7 @@ class BertForPreTraining(BertPreTrainedModel):
...
@@ -760,7 +760,7 @@ class BertForPreTraining(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
outputs = model(input_ids)
prediction_scores, seq_relationship_scores = outputs[:2]
prediction_scores, seq_relationship_scores = outputs[:2]
...
@@ -836,7 +836,7 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -836,7 +836,7 @@ class BertForMaskedLM(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
loss, prediction_scores = outputs[:2]
...
@@ -919,7 +919,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
...
@@ -919,7 +919,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
outputs = model(input_ids)
seq_relationship_scores = outputs[0]
seq_relationship_scores = outputs[0]
...
@@ -984,7 +984,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
...
@@ -984,7 +984,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
loss, logits = outputs[:2]
...
@@ -1060,7 +1060,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
...
@@ -1060,7 +1060,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
input_ids = torch.tensor([tokenizer.encode(s
, add_special_tokens=True
) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
loss, classification_scores = outputs[:2]
...
@@ -1134,7 +1134,7 @@ class BertForTokenClassification(BertPreTrainedModel):
...
@@ -1134,7 +1134,7 @@ class BertForTokenClassification(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute"
, add_special_tokens=True
)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
loss, scores = outputs[:2]
...
...
transformers/modeling_camembert.py
View file @
8c276b9c
...
@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
...
@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
import
logging
import
logging
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
RobertaForMultipleChoice
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
RobertaForMultipleChoice
,
RobertaForTokenClassification
from
.configuration_camembert
import
CamembertConfig
from
.configuration_camembert
import
CamembertConfig
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
...
@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
...
@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
"""
"""
config_class
=
CamembertConfig
config_class
=
CamembertConfig
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
class CamembertForTokenClassification(RobertaForTokenClassification):
    r"""
    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Labels for computing the token classification loss.
        Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
        model = CamembertForTokenClassification.from_pretrained('camembert-base')
        input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]
    """
    # Thin subclass: the body only overrides the configuration class and the
    # pretrained-weights map; all forward/loss logic is inherited unchanged
    # from RobertaForTokenClassification.

    # Configuration class used by from_pretrained to parse the model config.
    config_class = CamembertConfig
    # Mapping from shortcut names (e.g. 'camembert-base') to pretrained weight URLs.
    pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
transformers/modeling_distilbert.py
View file @
8c276b9c
...
@@ -44,6 +44,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -44,6 +44,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin"
,
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin"
,
'distilbert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-pytorch_model.bin"
,
'distilbert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-pytorch_model.bin"
,
'distilbert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-pytorch_model.bin"
,
}
}
...
...
transformers/modeling_tf_albert.py
0 → 100644
View file @
8c276b9c
This diff is collapsed.
Click to expand it.
transformers/modeling_tf_auto.py
View file @
8c276b9c
...
@@ -109,6 +109,9 @@ class TFAutoModel(object):
...
@@ -109,6 +109,9 @@ class TFAutoModel(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
...
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
...
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
...
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
...
transformers/modeling_tf_utils.py
View file @
8c276b9c
...
@@ -191,6 +191,9 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -191,6 +191,9 @@ class TFPreTrainedModel(tf.keras.Model):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -216,6 +219,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -216,6 +219,7 @@ class TFPreTrainedModel(tf.keras.Model):
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_pt
=
kwargs
.
pop
(
'from_pt'
,
False
)
from_pt
=
kwargs
.
pop
(
'from_pt'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
# Load config
# Load config
...
@@ -224,6 +228,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -224,6 +228,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path
,
*
model_args
,
pretrained_model_name_or_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
**
kwargs
**
kwargs
)
)
else
:
else
:
...
@@ -251,7 +256,8 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -251,7 +256,8 @@ class TFPreTrainedModel(tf.keras.Model):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
)
except
EnvironmentError
as
e
:
except
EnvironmentError
as
e
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
logger
.
error
(
...
...
transformers/modeling_utils.py
View file @
8c276b9c
...
@@ -291,6 +291,9 @@ class PreTrainedModel(nn.Module):
...
@@ -291,6 +291,9 @@ class PreTrainedModel(nn.Module):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -315,11 +318,16 @@ class PreTrainedModel(nn.Module):
...
@@ -315,11 +318,16 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
"""
if
"albert"
in
pretrained_model_name_or_path
and
"v2"
in
pretrained_model_name_or_path
:
logger
.
warning
(
"There is currently an upstream reproducibility issue with ALBERT v2 models. Please see "
+
"https://github.com/google-research/google-research/issues/119 for more information."
)
config
=
kwargs
.
pop
(
'config'
,
None
)
config
=
kwargs
.
pop
(
'config'
,
None
)
state_dict
=
kwargs
.
pop
(
'state_dict'
,
None
)
state_dict
=
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
force_download
=
kwargs
.
pop
(
'force_download'
,
False
)
resume_download
=
kwargs
.
pop
(
'resume_download'
,
False
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
proxies
=
kwargs
.
pop
(
'proxies'
,
None
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
...
@@ -329,6 +337,7 @@ class PreTrainedModel(nn.Module):
...
@@ -329,6 +337,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path
,
*
model_args
,
pretrained_model_name_or_path
,
*
model_args
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
cache_dir
=
cache_dir
,
return_unused_kwargs
=
True
,
force_download
=
force_download
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
proxies
=
proxies
,
**
kwargs
**
kwargs
)
)
...
@@ -361,7 +370,8 @@ class PreTrainedModel(nn.Module):
...
@@ -361,7 +370,8 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
msg
=
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment