chenpangpang / transformers · Commit ce50305e (unverified)

Merge pull request #2270 from aaugustin/remove-python-2

Remove support for Python 2

Authored Dec 22, 2019 by Aymeric Augustin; committed by GitHub on Dec 22, 2019.
Parents: b6ea0f43, 1a948d70
Changes: 155 in the full commit; this page shows 20 changed files with 63 additions and 179 deletions (+63 −179).
src/transformers/modeling_xlnet.py                  +2  −5
src/transformers/optimization_tf.py                 +0  −1
src/transformers/pipelines.py                       +2  −3
src/transformers/tokenization_albert.py             +7  −29
src/transformers/tokenization_auto.py               +0  −1
src/transformers/tokenization_bert.py               +2  −4
src/transformers/tokenization_bert_japanese.py      +1  −7
src/transformers/tokenization_camembert.py          +3  −3
src/transformers/tokenization_ctrl.py               +3  −4
src/transformers/tokenization_distilbert.py         +0  −1
src/transformers/tokenization_gpt2.py               +8  −24
src/transformers/tokenization_openai.py             +2  −3
src/transformers/tokenization_roberta.py            +1  −10
src/transformers/tokenization_t5.py                 +4  −19
src/transformers/tokenization_transfo_xl.py         +3  −9
src/transformers/tokenization_utils.py              +11 −18
src/transformers/tokenization_xlm.py                +3  −4
src/transformers/tokenization_xlm_roberta.py        +3  −3
src/transformers/tokenization_xlnet.py              +7  −29
templates/adding_a_new_example_script/run_xxx.py    +1  −2
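Note: the recurring edit across these files is the same two-step cleanup: drop the from __future__ imports, which are no-ops on Python 3, and collapse six/unicode text checks to plain str. A minimal standalone illustration of the idiom the diffs converge on (not library code):

# Python 2 compatible form, as removed below:
#     isinstance(name, six.string_types) or (six.PY2 and isinstance(name, unicode))
# Python 3 only form, as added below:
def is_text(value):
    return isinstance(value, str)

assert is_text("bert-base-uncased")
assert not is_text(b"bert-base-uncased")  # bytes are not text on Python 3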
src/transformers/modeling_xlnet.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 """ PyTorch XLNet model.
 """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import math
-import sys
 import torch
 from torch import nn
@@ -420,9 +419,7 @@ class XLNetFeedForward(nn.Module):
         self.layer_1 = nn.Linear(config.d_model, config.d_inner)
         self.layer_2 = nn.Linear(config.d_inner, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
-        if isinstance(config.ff_activation, str) or (
-            sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)  # noqa: F821
-        ):
+        if isinstance(config.ff_activation, str):
             self.activation_function = ACT2FN[config.ff_activation]
         else:
             self.activation_function = config.ff_activation
src/transformers/optimization_tf.py
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Functions and classes related to optimization (weight updates)."""
-from __future__ import absolute_import, division, print_function
 import re
src/transformers/pipelines.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import, division, print_function, unicode_literals
 import csv
 import json
@@ -26,7 +26,6 @@ from os.path import abspath, exists
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
-import six
 from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
 from .configuration_utils import PretrainedConfig
@@ -939,7 +938,7 @@ def pipeline(
         modelcard = config
     # Instantiate tokenizer if needed
-    if isinstance(tokenizer, six.string_types):
+    if isinstance(tokenizer, str):
         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     # Instantiate config if needed
src/transformers/tokenization_albert.py
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for ALBERT model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import unicodedata
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -139,9 +137,6 @@ class AlbertTokenizer(PreTrainedTokenizer):
             outputs = inputs
         outputs = outputs.replace("``", '"').replace("''", '"')
-        if six.PY2 and isinstance(outputs, str):
-            outputs = outputs.decode("utf-8")
         if not self.keep_accents:
             outputs = unicodedata.normalize("NFKD", outputs)
             outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
@@ -150,14 +145,9 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return outputs
-    def _tokenize(self, text, return_unicode=True, sample=False):
-        """ Tokenize a string.
-            return_unicode is used only for py2
-        """
+    def _tokenize(self, text, sample=False):
+        """ Tokenize a string. """
         text = self.preprocess_text(text)
-        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
-        if six.PY2 and isinstance(text, unicode):  # noqa: F821
-            text = text.encode("utf-8")
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
@@ -177,27 +167,15 @@ class AlbertTokenizer(PreTrainedTokenizer):
             else:
                 new_pieces.append(piece)
-        # note(zhiliny): convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in new_pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            new_pieces = ret_pieces
         return new_pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.sp_model.PieceToId(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        token = self.sp_model.IdToPiece(index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
-        return token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
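The ALBERT preprocessing survives the change intact apart from the dropped py2 decode branch; what remains is plain str manipulation. A self-contained sketch of that preprocessing step, simplified from the context lines above (not the exact method):

import unicodedata

def preprocess_text(inputs, keep_accents=False):
    # Normalize the LaTeX-style quotes, then optionally strip accents by
    # NFKD-decomposing and dropping combining marks, as in the hunk above.
    outputs = inputs.replace("``", '"').replace("''", '"')
    if not keep_accents:
        outputs = unicodedata.normalize("NFKD", outputs)
        outputs = "".join(c for c in outputs if not unicodedata.combining(c))
    return outputs

print(preprocess_text("``café''"))  # -> "cafe" wrapped in double quotes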
src/transformers/tokenization_auto.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Auto Model class. """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
src/transformers/tokenization_bert.py
@@ -14,13 +14,11 @@
 # limitations under the License.
 """Tokenization classes."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import collections
 import logging
 import os
 import unicodedata
-from io import open
 from .tokenization_utils import PreTrainedTokenizer
@@ -203,11 +201,11 @@ class BertTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.vocab.get(token, self.vocab.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.ids_to_tokens.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
src/transformers/tokenization_bert_japanese.py
@@ -14,15 +14,12 @@
 # limitations under the License.
 """Tokenization classes."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import collections
 import logging
 import os
 import unicodedata
-import six
 from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer, load_vocab
@@ -195,10 +192,7 @@ class MecabTokenizer(object):
         never_split = self.never_split + (never_split if never_split is not None else [])
         tokens = []
-        if six.PY2:
-            mecab_output = self.mecab.parse(text.encode("utf-8")).decode("utf-8")
-        else:
-            mecab_output = self.mecab.parse(text)
+        mecab_output = self.mecab.parse(text)
         cursor = 0
         for line in mecab_output.split("\n"):
src/transformers/tokenization_camembert.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 """ Tokenization classes for Camembert model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
@@ -155,7 +155,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         elif self.sp_model.PieceToId(token) == 0:
@@ -164,7 +164,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.fairseq_offset + self.sp_model.PieceToId(token)
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
src/transformers/tokenization_ctrl.py
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Salesforce CTRL."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
-from io import open
 import regex as re
@@ -204,11 +203,11 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
src/transformers/tokenization_distilbert.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Tokenization classes for DistilBERT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
src/transformers/tokenization_gpt2.py
@@ -13,28 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
-import sys
+from functools import lru_cache
-from io import open
 import regex as re
 from .tokenization_utils import PreTrainedTokenizer
-try:
-    from functools import lru_cache
-except ImportError:
-    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
-    def lru_cache():
-        return lambda func: func
 logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
@@ -80,7 +70,6 @@ def bytes_to_unicode():
     This is a signficant percentage of your normal, say, 32K bpe vocab.
     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     """
-    _chr = unichr if sys.version_info[0] == 2 else chr  # noqa: F821
     bs = (
         list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
     )
@@ -91,7 +80,7 @@ def bytes_to_unicode():
             bs.append(b)
             cs.append(2 ** 8 + n)
             n += 1
-    cs = [_chr(n) for n in cs]
+    cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
@@ -212,23 +201,18 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = "".join(
-                    self.byte_encoder[ord(b)] for b in token
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
-            else:
-                token = "".join(
-                    self.byte_encoder[b] for b in token.encode("utf-8")
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
         return bpe_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index)
     def convert_tokens_to_string(self, tokens):
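The GPT-2 change works because iterating over a bytes object yields ints on Python 3, so the byte encoder can be indexed directly without ord() or a unichr shim. A condensed sketch of bytes_to_unicode, reassembled from the context lines above (treat it as an approximation, not the exact library function):

def bytes_to_unicode():
    # Map every byte value to a printable unicode character so the BPE
    # never has to deal with whitespace or control bytes directly.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

byte_encoder = bytes_to_unicode()
raw = "héllo".encode("utf-8")
assert all(isinstance(b, int) for b in raw)   # Python 3: bytes iterate as ints
print("".join(byte_encoder[b] for b in raw))  # the BPE-safe mapped string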
src/transformers/tokenization_openai.py
@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
 import re
-from io import open
 from .tokenization_bert import BasicTokenizer
 from .tokenization_utils import PreTrainedTokenizer
@@ -177,7 +176,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
src/transformers/tokenization_roberta.py
@@ -13,22 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for RoBERTa."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 from .tokenization_gpt2 import GPT2Tokenizer
-try:
-    from functools import lru_cache
-except ImportError:
-    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
-    def lru_cache():
-        return lambda func: func
 logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
src/transformers/tokenization_t5.py
@@ -14,15 +14,12 @@
 # limitations under the License.
 """ Tokenization class for model T5."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import re
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -138,41 +135,29 @@ class T5Tokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)
-    def _tokenize(self, text, return_unicode=True, sample=False):
+    def _tokenize(self, text, sample=False):
         """ Take as input a string and return a list of strings (tokens) for words/sub-words
         """
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
         else:
             pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
-        # convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            pieces = ret_pieces
         return pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token.startswith("<extra_id_"):
             match = re.match(r"<extra_id_(\d+)>", token)
             num = int(match.group(1))
             return self.vocab_size - num - 1
         return self.sp_model.piece_to_id(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index < self.sp_model.get_piece_size():
             token = self.sp_model.IdToPiece(index)
         else:
             token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
         return token
     def convert_tokens_to_string(self, tokens):
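Unrelated to the Python 2 cleanup but visible in the hunk above: T5 maps its <extra_id_N> sentinel tokens to ids counted down from the end of the vocabulary. A standalone sketch of that mapping (vocab_size here is just an illustrative value):

import re

def extra_id_to_index(token, vocab_size):
    # "<extra_id_0>" gets the last id, "<extra_id_1>" the one before it, etc.
    match = re.match(r"<extra_id_(\d+)>", token)
    num = int(match.group(1))
    return vocab_size - num - 1

assert extra_id_to_index("<extra_id_0>", vocab_size=32100) == 32099
assert extra_id_to_index("<extra_id_99>", vocab_size=32100) == 32000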
src/transformers/tokenization_transfo_xl.py
@@ -16,14 +16,13 @@
 """ Tokenization classes for Transformer XL model.
     Adapted from https://github.com/kimiyoung/transformer-xl.
 """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import glob
 import logging
 import os
-import sys
+import pickle
 from collections import Counter, OrderedDict
-from io import open
 import numpy as np
@@ -36,11 +35,6 @@ try:
 except ImportError:
     pass
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
 logger = logging.getLogger(__name__)
@@ -238,7 +232,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         return self.idx2sym[idx]
     def _convert_token_to_id(self, sym):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
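Background for the hunk above: Python 3 ships a single pickle module and uses the C accelerator automatically when available, so the old cPickle fallback collapses to a plain top-level import. A tiny sanity check of the idiom (not library code):

import pickle
from collections import Counter, OrderedDict

vocab = OrderedDict(Counter("the cat sat on the mat".split()).most_common())
blob = pickle.dumps(vocab, protocol=pickle.HIGHEST_PROTOCOL)
assert pickle.loads(blob) == vocab  # round-trips without any cPickle shim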
src/transformers/tokenization_utils.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import copy
 import itertools
@@ -21,9 +21,6 @@ import json
 import logging
 import os
 import re
-from io import open
-import six
 from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
@@ -251,11 +248,9 @@ class PreTrainedTokenizer(object):
         for key, value in kwargs.items():
             if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                 if key == "additional_special_tokens":
-                    assert isinstance(value, (list, tuple)) and all(
-                        isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value  # noqa: F821
-                    )
+                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
                 else:
-                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))  # noqa: F821
+                    assert isinstance(value, str)
                 setattr(self, key, value)
     @classmethod
@@ -567,7 +562,7 @@ class PreTrainedTokenizer(object):
         to_add_tokens = []
         for token in new_tokens:
-            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))  # noqa: F821
+            assert isinstance(token, str)
             if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
                 token = token.lower()
             if (
@@ -649,12 +644,10 @@ class PreTrainedTokenizer(object):
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
             if key == "additional_special_tokens":
-                assert isinstance(value, (list, tuple)) and all(
-                    isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value  # noqa: F821
-                )
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
                 added_tokens += self.add_tokens(value)
             else:
-                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))  # noqa: F821
+                assert isinstance(value, str)
                 added_tokens += self.add_tokens([value])
             logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)
@@ -740,13 +733,13 @@ class PreTrainedTokenizer(object):
         raise NotImplementedError
     def convert_tokens_to_ids(self, tokens):
-        """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
+        """ Converts a single token, or a sequence of tokens, (str) in a single integer id
            (resp. a sequence of ids), using the vocabulary.
        """
        if tokens is None:
            return None
-        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):  # noqa: F821
+        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)
        ids = []
@@ -901,9 +894,9 @@ class PreTrainedTokenizer(object):
         """
         def get_input_ids(text):
-            if isinstance(text, six.string_types):
+            if isinstance(text, str):
                 return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                 return self.convert_tokens_to_ids(text)
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                 return text
@@ -1297,7 +1290,7 @@ class PreTrainedTokenizer(object):
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """ Converts a single index or a sequence of indices (integers) in a token "
-            (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
+            (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
             Args:
                 skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
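The asserts above all reduce to the same Python 3 invariant: a value is text if and only if it is a str. A simplified sketch of the validation the special-token setters now perform (not the exact library code):

def validate_special_token(value):
    # additional_special_tokens takes a list/tuple of str; every other
    # special-token attribute takes a single str.
    if isinstance(value, (list, tuple)):
        assert all(isinstance(t, str) for t in value)
    else:
        assert isinstance(value, str)

validate_special_token(["<extra_0>", "<extra_1>"])
validate_special_token("<mask>")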
src/transformers/tokenization_xlm.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for XLM."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
@@ -21,7 +21,6 @@ import os
 import re
 import sys
 import unicodedata
-from io import open
 import sacremoses as sm
@@ -798,11 +797,11 @@ class XLMTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
src/transformers/tokenization_xlm_roberta.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 """ Tokenization classes for XLM-RoBERTa model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
@@ -171,13 +171,13 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         return self.sp_model.PieceToId(token) + self.fairseq_offset
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
src/transformers/tokenization_xlnet.py
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for XLNet model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import unicodedata
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -139,9 +137,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
         outputs = inputs
         outputs = outputs.replace("``", '"').replace("''", '"')
-        if six.PY2 and isinstance(outputs, str):
-            outputs = outputs.decode("utf-8")
         if not self.keep_accents:
             outputs = unicodedata.normalize("NFKD", outputs)
             outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
@@ -150,14 +145,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
         return outputs
-    def _tokenize(self, text, return_unicode=True, sample=False):
-        """ Tokenize a string.
-            return_unicode is used only for py2
-        """
+    def _tokenize(self, text, sample=False):
+        """ Tokenize a string. """
         text = self.preprocess_text(text)
-        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
-        if six.PY2 and isinstance(text, unicode):  # noqa: F821
-            text = text.encode("utf-8")
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
@@ -177,27 +167,15 @@ class XLNetTokenizer(PreTrainedTokenizer):
             else:
                 new_pieces.append(piece)
-        # note(zhiliny): convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in new_pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            new_pieces = ret_pieces
         return new_pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.sp_model.PieceToId(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        token = self.sp_model.IdToPiece(index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
-        return token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
templates/adding_a_new_example_script/run_xxx.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Finetuning the library models for task XXX."""
-from __future__ import absolute_import, division, print_function
 import argparse
 import glob
@@ -156,7 +155,7 @@ def train(args, train_dataset, model, tokenizer):
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
-    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
+    set_seed(args)  # Added here for reproductibility
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
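The template edit above only trims the "(even between python 2 and 3)" note on the set_seed(args) call. For context, a typical seeding helper in these example scripts looks roughly like the following; this is a hypothetical reconstruction (args.seed and args.n_gpu are assumed attributes), and the template's actual helper may differ:

import random
import numpy as np
import torch

def set_seed(args):
    # Seed every RNG the training loop touches so reruns are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)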