Commit 6c41a8f5
Authored Aug 08, 2019 by LysandreJik
Parent: e367ac46

Encode and Decode are back in the superclass. They now handle sentence pairs special tokens.
Showing 4 changed files with 81 additions and 79 deletions (+81 -79):

  pytorch_transformers/__init__.py               +1   -2
  pytorch_transformers/modeling_roberta.py       +1   -2
  pytorch_transformers/tokenization_roberta.py   +41  -67
  pytorch_transformers/tokenization_utils.py     +38  -8
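The diff below moves sentence-pair handling into the shared PreTrainedTokenizer, so any tokenizer built on it can encode a pair in a single call. A minimal usage sketch (the model name, input strings, and flag choices are illustrative assumptions, not part of the commit):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Single sentence: encode() delegates to convert_tokens_to_ids(tokenize(text)).
    single_ids = tokenizer.encode("The cat sat on the mat.")

    # Sentence pair: the new superclass encode() assembles
    # cls + first sentence + sep (optionally doubled) + second sentence + sep.
    pair_ids = tokenizer.encode("The cat sat on the mat.", "It was tired.")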
pytorch_transformers/__init__.py

@@ -7,7 +7,6 @@ from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_roberta import RobertaTokenizer
-from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)
+from .tokenization_utils import (PreTrainedTokenizer)

@@ -39,7 +38,7 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
                            XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
-from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel,
+from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel,
                                RobertaForSequenceClassification,
                                ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                              prune_layer, Conv1D)
pytorch_transformers/modeling_roberta.py

@@ -23,7 +23,7 @@ import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss

 from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
                                                 BertLayerNorm, BertModel,

@@ -144,7 +144,6 @@ class RobertaLMHead(nn.Module):
         return x


 class RobertaForSequenceClassification(BertPreTrainedModel):
     """
     Roberta Model with a classifier head on top.
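The only substantive change in this file is importing MSELoss next to CrossEntropyLoss. The hunk does not show how it is used, but the conventional pattern such an import supports is treating a single-label head as regression via mean-squared error. The sketch below illustrates that pattern; it is an assumption, not code from this commit:

    import torch
    from torch.nn import CrossEntropyLoss, MSELoss

    def sequence_classification_loss(logits, labels, num_labels):
        # Single-label heads are treated as regression (MSE); otherwise cross-entropy.
        if num_labels == 1:
            return MSELoss()(logits.view(-1), labels.view(-1))
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))

    # Illustrative call with random tensors.
    logits = torch.randn(4, 3)
    labels = torch.tensor([0, 2, 1, 0])
    print(sequence_classification_loss(logits, labels, num_labels=3))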
pytorch_transformers/tokenization_roberta.py

@@ -21,18 +21,19 @@ import logging
 import re
 from io import open

 import six
 import os

-from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer

 logger = logging.getLogger(__name__)

-VOCAB_FILES_NAMES = {
-    'vocab_file': 'dict.txt',
+DICT_FILES_NAMES = {
+    'dict_file': 'dict.txt',
 }

-PRETRAINED_VOCAB_FILES_MAP = {
-    'vocab_file':
+PRETRAINED_DICT_FILES_MAP = {
+    'dict_file':
     {
         'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
         'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",

@@ -178,89 +179,62 @@ class RobertaTokenizer(PreTrainedTokenizer):
     RoBERTa tokenizer. Peculiarities:
         - GPT-2 tokenizer with a different integer mapping on top.
     """
-    vocab_files_names = VOCAB_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    vocab_files_names = DICT_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", **kwargs):
-        super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token,
-                                               eos_token=eos_token, **kwargs)
+    def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>",
+                 sep_token="</s>", cls_token="<s>", unk_token="<unk>", **kwargs):
+        super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token,
+                                               eos_token=eos_token, unk_token=unk_token, **kwargs)

-        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        self.dictionary = Dictionary.load(vocab_file)
+        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
+        self.dictionary = Dictionary.load(dict_file)

     @property
     def vocab_size(self):
         return len(self.dictionary.indices)

     def _tokenize(self, text):
         """ Use GPT-2 Tokenizer """
         return self.gpt2_tokenizer._tokenize(text)

-    def encode(self, text, *args):
-        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
-        """
-        bpe_sentence = [self.cls_token] + \
-            self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text)) + \
-            [self.sep_token]
-        if len(args):
-            for additional_sentence in args:
-                bpe_sentence += [self.sep_token] + \
-                    self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(additional_sentence)) + \
-                    [self.sep_token]
-        return self.dictionary.encode_line(' '.join([str(token) for token in bpe_sentence]), append_eos=False)

-    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
-        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
-            with options to remove special tokens and clean up tokenization spaces.
-            Handles sentence pairs.
-        """
-        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-        if any(isinstance(element, list) for element in filtered_tokens):
-            texts = []
-            for element in filtered_tokens:
-                text = self.convert_tokens_to_string(element)
-                if clean_up_tokenization_spaces:
-                    text = clean_up_tokenization(text)
-                texts.append(text)
-            return texts
-        else:
-            text = self.convert_tokens_to_string(filtered_tokens)
-            if clean_up_tokenization_spaces:
-                text = clean_up_tokenization(text)
-            return text

     def _convert_token_to_id(self, token):
         if self.dictionary.index(token) != 3:
             return self.dictionary.index(token)
         return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))

     def _convert_id_to_token(self, index):
         symbol = self.dictionary[index]
         try:
             idx = int(symbol)
             return self.gpt2_tokenizer._convert_id_to_token(idx)
-        except:
+        except ValueError:
             return symbol

     def convert_tokens_to_string(self, tokens):
         return self.gpt2_tokenizer.convert_tokens_to_string(tokens)

+    def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
+        cls = [self._convert_token_to_id(self.cls_token)]
+        tokens = super().convert_tokens_to_ids(tokens)
+        sep = [self._convert_token_to_id(self.sep_token)]
+        return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens

     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+        # Remove the first and last tokens which are cls and sep tokens
+        ids = ids[1:-1]
+        # If multi sentence, then split (multi sentence found by looking for two sequential sep tokens)
+        ids = [list(map(int, example.split(' '))) for example in ' '.join([str(id) for id in ids]).split(' 2 2 ')]
-        return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
+        if len(ids) == 1:
+            tokens = self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), ids[0])))
+        else:
+            tokens = []
+            for example in ids:
+                tokens += [self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), example)))]
+        return tokens

     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
         dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])

         with open(dict_file, 'w', encoding='utf-8') as f:
             for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
                 f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")

-    def convert_tokens_to_ids(self, tokens):
-        tokens = " ".join(str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(tokens))
-        bpe_sentence = '<s> ' + tokens + ' </s>'
-        return self.dictionary.encode_line(bpe_sentence, append_eos=False)

         vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)

         return vocab_files + (dict_file,)
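A hedged sketch of how the reworked RobertaTokenizer constructor can be used: the dictionary path is a placeholder, and the new bpe_tokenizer argument lets a caller inject an existing GPT2Tokenizer instead of letting the class download one (it falls back to GPT2Tokenizer.from_pretrained("gpt2") when the argument is None):

    from pytorch_transformers import GPT2Tokenizer, RobertaTokenizer

    # 'path/to/dict.txt' is a placeholder for a RoBERTa dict.txt dictionary file.
    bpe = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer = RobertaTokenizer('path/to/dict.txt', bpe_tokenizer=bpe)

    # Single sentences go through the overridden convert_tokens_to_ids(),
    # which wraps them in <s> ... </s>.
    ids = tokenizer.encode("Hello world")

    # Sentence pairs go through the shared encode(); RoBERTa separates the
    # two sentences with a doubled separator token.
    pair_ids = tokenizer.encode("Hello world", "How are you?", double_sep_token=True)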
pytorch_transformers/tokenization_utils.py

@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
         """
         raise NotImplementedError

-    def convert_tokens_to_ids(self, tokens):
+    def convert_tokens_to_ids(self, tokens, **kwargs):
         """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
             (resp. a sequence of ids), using the vocabulary.
         """

@@ -520,12 +520,29 @@ class PreTrainedTokenizer(object):
         raise NotImplementedError

-    def encode(self, text):
+    def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
         """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
             Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
         """
-        return self.convert_tokens_to_ids(self.tokenize(text))
+        if len(text) == 1:
+            return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
+
+        if len(text) > 2:
+            logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
+                           "initial two.")
+
+        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
+
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        n_sep_token = 2 if double_sep_token else 1
+
+        tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
+        tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
+
+        return tokens

     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
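The pair-assembly logic above is small enough to illustrate standalone. The following is a simplified re-statement of it (the function name and the token ids are made up for the example), showing how the cls_token_at_end and double_sep_token flags cover the BERT, XLNet, and RoBERTa conventions:

    def assemble_pair(first_ids, second_ids, cls_id, sep_id,
                      cls_token_at_end=False, double_sep_token=False):
        # Mirrors the new PreTrainedTokenizer.encode() assembly for two sentences.
        sep = [sep_id]
        cls = [cls_id]
        n_sep_token = 2 if double_sep_token else 1
        tokens = first_ids + sep * n_sep_token + second_ids + sep
        return (tokens + cls) if cls_token_at_end else (cls + tokens)

    # BERT-style:    [CLS] A [SEP] B [SEP]
    print(assemble_pair([10, 11], [20, 21], cls_id=101, sep_id=102))
    # XLNet-style:   A <sep> B <sep> <cls>
    print(assemble_pair([10, 11], [20, 21], cls_id=3, sep_id=4, cls_token_at_end=True))
    # RoBERTa-style: <s> A </s></s> B </s>
    print(assemble_pair([10, 11], [20, 21], cls_id=0, sep_id=2, double_sep_token=True))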
@@ -560,7 +577,8 @@ class PreTrainedTokenizer(object):
         """
         return ' '.join(self.convert_ids_to_tokens(tokens))

-    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True,
+               cls_token_at_end=False, double_sep_token=False):
         """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
             with options to remove special tokens and clean up tokenization spaces.

@@ -568,8 +586,20 @@ class PreTrainedTokenizer(object):
         """
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
         text = self.convert_tokens_to_string(filtered_tokens)
-        if clean_up_tokenization_spaces:
-            text = self.clean_up_tokenization(text)
-        return text
+
+        if self.sep_token is not None and self.sep_token in text:
+            text = text.replace(self.cls_token, self.sep_token)
+            split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
+            if clean_up_tokenization_spaces:
+                clean_text = [self.clean_up_tokenization(text) for text in split_text]
+                return clean_text
+            else:
+                return split_text
+        else:
+            if clean_up_tokenization_spaces:
+                clean_text = self.clean_up_tokenization(text)
+                return clean_text
+            else:
+                return text

     @property

@@ -602,7 +632,7 @@ class PreTrainedTokenizer(object):
             class attributes (cls_token, unk_token...).
         """
         all_toks = self.all_special_tokens
-        all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
+        all_ids = list(self._convert_token_to_id(t) for t in all_toks)
         return all_ids

     @staticmethod
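On the decoding side, the new behaviour is that when the decoded text still contains the separator token, decode() returns one string per sentence instead of a single concatenated string. A hedged sketch of that behaviour (the model name and the exact output are illustrative, and assume the model's tokenizer does not override the base decode()):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pair_ids = tokenizer.encode("The cat sat on the mat.", "It was tired.")

    # The [CLS]/[SEP] markers are used as split points, so this returns a list,
    # roughly ['the cat sat on the mat.', 'it was tired.'].
    print(tokenizer.decode(pair_ids))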