chenpangpang / transformers · Commits

Commit 8a618e0a
authored Sep 25, 2019 by thomwolf

clean up __init__

parent 3b7fb48c
Showing 2 changed files with 61 additions and 38 deletions

    pytorch_transformers/__init__.py             +19  -31
    pytorch_transformers/tokenization_utils.py   +42  -7
pytorch_transformers/__init__.py
@@ -16,7 +16,21 @@ import logging
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
-# Tokenizer
+# Files and general utilities
+from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
+                         cached_path, add_start_docstrings, add_end_docstrings,
+                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
+                         is_tf_available, is_torch_available)
+
+from .data import (is_sklearn_available,
+                   InputExample, InputFeatures, DataProcessor,
+                   glue_output_modes, glue_convert_examples_to_features,
+                   glue_processors, glue_tasks_num_labels)
+
+if is_sklearn_available():
+    from .data import glue_compute_metrics
+
+# Tokenizers
 from .tokenization_utils import (PreTrainedTokenizer)
 from .tokenization_auto import AutoTokenizer
 from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
@@ -41,13 +55,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
 from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 
 # Modeling
-try:
-    import torch
-    _torch_available = True  # pylint: disable=invalid-name
-except ImportError:
-    _torch_available = False  # pylint: disable=invalid-name
-
-if _torch_available:
+if is_torch_available():
     logger.info("PyTorch version {} available.".format(torch.__version__))
 
     from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
@@ -87,14 +95,7 @@ if _torch_available:
 
 # TensorFlow
-try:
-    import tensorflow as tf
-    assert int(tf.__version__[0]) >= 2
-    _tf_available = True  # pylint: disable=invalid-name
-except (ImportError, AssertionError):
-    _tf_available = False  # pylint: disable=invalid-name
-
-if _tf_available:
+if is_tf_available():
     logger.info("TensorFlow version {} available.".format(tf.__version__))
 
     from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
@@ -151,7 +152,8 @@ if _tf_available:
                                          load_distilbert_pt_weights_in_tf2,
                                          TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 
-if _tf_available and _torch_available:
+# TF 2.0 <=> PyTorch conversion utilities
+if is_tf_available() and is_torch_available():
     from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
                                             load_pytorch_checkpoint_in_tf2_model,
                                             load_pytorch_weights_in_tf2_model,
@@ -159,17 +161,3 @@ if _tf_available and _torch_available:
                                             load_tf2_checkpoint_in_pytorch_model,
                                             load_tf2_weights_in_pytorch_model,
                                             load_tf2_model_in_pytorch_model)
-
-# Files and general utilities
-from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
-                         cached_path, add_start_docstrings, add_end_docstrings,
-                         WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
-                         is_tf_available, is_torch_available)
-
-from .data import (is_sklearn_available,
-                   InputExample, InputFeatures, DataProcessor,
-                   glue_output_modes, glue_convert_examples_to_features,
-                   glue_processors, glue_tasks_num_labels)
-
-if is_sklearn_available():
-    from .data import glue_compute_metrics
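
Note on the cleanup: the try/except framework probing deleted from this file is not gone, it now sits behind the is_torch_available() / is_tf_available() helpers imported from .file_utils at the top. A minimal sketch of such helpers, approximated from the blocks removed above (the actual file_utils implementation may differ in its details):

    # Sketch: centralize framework detection, mirroring the try/except
    # blocks deleted from __init__.py above. Not the verbatim file_utils code.
    try:
        import torch
        _torch_available = True
    except ImportError:
        _torch_available = False

    try:
        import tensorflow as tf
        assert int(tf.__version__[0]) >= 2  # keep the old TF 2.x requirement
        _tf_available = True
    except (ImportError, AssertionError):
        _tf_available = False

    def is_torch_available():
        return _torch_available

    def is_tf_available():
        return _tf_available

With the flags owned by file_utils, every module can branch on the same answer instead of re-probing the imports itself.
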
pytorch_transformers/tokenization_utils.py
@@ -23,7 +23,7 @@ import six
 import copy
 from io import open
 
-from .file_utils import cached_path, is_tf_available
+from .file_utils import cached_path, is_tf_available, is_torch_available
 
 if is_tf_available():
     import tensorflow as tf
@@ -690,7 +690,15 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError
 
-    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
+    def encode(self,
+               text,
+               text_pair=None,
+               add_special_tokens=False,
+               max_length=None,
+               stride=0,
+               truncate_first_sequence=True,
+               return_tensors=None,
+               **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -705,9 +713,24 @@ class PreTrainedTokenizer(object):
                 `convert_tokens_to_ids` method)
             add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                 to their model.
+            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defined the number of additional tokens.
+            truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
+                will be truncated.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
         """
-        encoded_inputs = self.encode_plus(text, text_pair=text_pair, add_special_tokens=add_special_tokens, **kwargs)
+        encoded_inputs = self.encode_plus(text,
+                                          text_pair=text_pair,
+                                          max_length=max_length,
+                                          add_special_tokens=add_special_tokens,
+                                          stride=stride,
+                                          truncate_first_sequence=truncate_first_sequence,
+                                          return_tensors=return_tensors,
+                                          **kwargs)
 
         return encoded_inputs["input_ids"]
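
To make the extended encode() surface concrete, a hypothetical call (the checkpoint name and sentences are illustrative, not part of this commit):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Encode a sentence pair, capping the total length at 16 ids.
    # truncate_first_sequence=True means the first sequence absorbs the cut.
    ids = tokenizer.encode("The quick brown fox jumps over the lazy dog.",
                           "A shorter second sentence.",
                           add_special_tokens=True,
                           max_length=16,
                           truncate_first_sequence=True)
    print(ids)  # a plain Python list of ints, since return_tensors is None
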
@@ -718,10 +741,11 @@ class PreTrainedTokenizer(object):
                     max_length=None,
                     stride=0,
                     truncate_first_sequence=True,
+                    return_tensors=None,
                     **kwargs):
         """
-        Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
-        method:
+        Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
         the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
 
         Args:
             text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
@@ -738,6 +762,8 @@ class PreTrainedTokenizer(object):
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
         """
@@ -759,10 +785,12 @@ class PreTrainedTokenizer(object):
                                       max_length=max_length,
                                       add_special_tokens=add_special_tokens,
                                       stride=stride,
-                                      truncate_first_sequence=truncate_first_sequence)
+                                      truncate_first_sequence=truncate_first_sequence,
+                                      return_tensors=return_tensors)
 
-    def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, truncate_first_sequence=True):
+    def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
+                          truncate_first_sequence=True, return_tensors=None):
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
         It adds special tokens, truncates
@@ -782,6 +810,8 @@ class PreTrainedTokenizer(object):
             truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
                 alongside a specified `max_length`, will truncate the first sequence if the total size is superior
                 than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
         Return:
             a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
@@ -816,6 +846,11 @@ class PreTrainedTokenizer(object):
         sequence = ids + pair_ids if pair else ids
         token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
 
+        if return_tensors == 'tf' and is_tf_available():
+            sequence = tf.constant(sequence)
+            token_type_ids = tf.constant(token_type_ids)
+        elif return_tensors == 'pt' and is_torch_available():
...
 
         encoded_inputs["input_ids"] = sequence
         encoded_inputs["token_type_ids"] = token_type_ids
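
This last hunk is where return_tensors actually takes effect: once the id lists are built, a requested 'tf' output is wrapped in tf.constant. The 'pt' branch is cut off in this view, but given the is_torch_available import added above it presumably mirrors the TF branch with torch tensors. A small usage sketch, continuing with the tokenizer from the earlier example (assumes TensorFlow 2.x is installed):

    enc = tokenizer.encode_plus("Hello world", add_special_tokens=True, return_tensors='tf')
    print(type(enc["input_ids"]))  # a tf.Tensor instead of a Python list

    enc = tokenizer.encode_plus("Hello world", add_special_tokens=True)
    print(type(enc["input_ids"]))  # list, since return_tensors defaults to None
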