ModelZoo / ResNet50_tensorflow / Commits

Commit 2416dd9c
Authored Mar 19, 2020 by A. Unique TensorFlower

Merge pull request #8256 from stagedml:tokenizer-update

PiperOrigin-RevId: 301949311

Parents: 27207a27, 30579e0f
Showing 2 changed files with 118 additions and 54 deletions.
official/nlp/transformer/utils/tokenizer.py       +82 -40
official/nlp/transformer/utils/tokenizer_test.py  +36 -14
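The two diffs below thread an optional master_char_set argument through the Subtokenizer API so that callers can control which characters count as "word" characters when a string is split into tokens; when the argument is omitted, the existing alphanumeric set is used and behaviour is unchanged. A minimal usage sketch (the vocab path "vocab.subwords" and the ASCII-only set are illustrative assumptions, not part of this commit):

    from official.nlp.transformer.utils import tokenizer

    # Default behaviour: the full alphanumeric character set is used.
    default_sub = tokenizer.Subtokenizer("vocab.subwords")

    # New in this commit: supply a custom "master" character set.
    ascii_alnum = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    custom_sub = tokenizer.Subtokenizer(
        "vocab.subwords", master_char_set=ascii_alnum)

    ids = custom_sub.encode("test? testing 123.", add_eos=True)
    text = custom_sub.decode(ids)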
official/nlp/transformer/utils/tokenizer.py (view file @ 2416dd9c)
@@ -28,6 +28,8 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
+# pylint: disable=g-complex-comprehension
+
 PAD = "<pad>"
 PAD_ID = 0
 EOS = "<EOS>"
@@ -45,27 +47,36 @@ _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
 _UNDEFINED_UNICODE = u"\u3013"
 
+
+def alphanumeric_char_set():
+  return set(
+      six.unichr(i)
+      for i in xrange(sys.maxunicode)
+      if (unicodedata.category(six.unichr(i)).startswith("L") or
+          unicodedata.category(six.unichr(i)).startswith("N")))
+
+
 # Set contains all letter and number characters.
-_ALPHANUMERIC_CHAR_SET = set(
-    six.unichr(i) for i in xrange(sys.maxunicode)
-    if (unicodedata.category(six.unichr(i)).startswith("L") or
-        unicodedata.category(six.unichr(i)).startswith("N")))
+_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()
 
 # min_count is the minimum number of times a subtoken must appear in the data
 # before before it is added to the vocabulary. The value is found using binary
 # search to obtain the target vocabulary size.
 _MIN_MIN_COUNT = 1  # min value to use when binary searching for min_count
 _MAX_MIN_COUNT = 1000  # max value to use when binary searching for min_count
 
 
 class Subtokenizer(object):
   """Encodes and decodes strings to/from integer IDs."""
 
-  def __init__(self, vocab_file, reserved_tokens=None):
+  def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
     """Initializes class, creating a vocab file if data_files is provided."""
     tf.compat.v1.logging.info("Initializing Subtokenizer from file %s." %
                               vocab_file)
+    if master_char_set is None:
+      master_char_set = _ALPHANUMERIC_CHAR_SET
 
     if reserved_tokens is None:
       reserved_tokens = RESERVED_TOKENS
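The comprehension that used to build _ALPHANUMERIC_CHAR_SET at import time is now wrapped in alphanumeric_char_set(), and __init__ accepts an optional master_char_set that defaults to that set. A sketch of how a caller might derive an alternative set (the letters-only filter and the vocab path are assumptions, not something this commit ships):

    import unicodedata

    from official.nlp.transformer.utils import tokenizer

    # Reuse the module's own constructor for the default set...
    alnum = tokenizer.alphanumeric_char_set()

    # ...and, hypothetically, narrow it to letter categories only.
    letters_only = set(
        c for c in alnum if unicodedata.category(c).startswith("L"))

    sub = tokenizer.Subtokenizer("vocab.subwords", master_char_set=letters_only)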
@@ -78,13 +89,20 @@ class Subtokenizer(object):
     self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
 
     # Create cache to speed up subtokenization
     self._cache_size = 2**20
     self._cache = [(None, None)] * self._cache_size
+    self._master_char_set = master_char_set
 
   @staticmethod
-  def init_from_files(
-      vocab_file, files, target_vocab_size, threshold, min_count=None,
-      file_byte_limit=1e6, reserved_tokens=None, correct_strip=True):
+  def init_from_files(vocab_file,
+                      files,
+                      target_vocab_size,
+                      threshold,
+                      min_count=None,
+                      file_byte_limit=1e6,
+                      reserved_tokens=None,
+                      correct_strip=True,
+                      master_char_set=None):
     """Create subtoken vocabulary based on files, and save vocab to file.
 
     Args:
@@ -101,10 +119,13 @@ class Subtokenizer(object):
       reserved_tokens: List of string tokens that are guaranteed to be at the
         beginning of the subtoken vocabulary list.
       correct_strip: Whether to convert text to unicode before strip.
+      master_char_set: the char set.
 
     Returns:
       Subtokenizer object
     """
+    if master_char_set is None:
+      master_char_set = _ALPHANUMERIC_CHAR_SET
     if reserved_tokens is None:
       reserved_tokens = RESERVED_TOKENS
@@ -112,7 +133,8 @@ class Subtokenizer(object):
       tf.compat.v1.logging.info("Vocab file already exists (%s)" % vocab_file)
     else:
       tf.compat.v1.logging.info("Begin steps to create subtoken vocabulary...")
-      token_counts = _count_tokens(files, file_byte_limit, correct_strip)
+      token_counts = _count_tokens(files, file_byte_limit, correct_strip,
+                                   master_char_set)
       alphabet = _generate_alphabet_dict(token_counts)
       subtoken_list = _generate_subtokens_with_target_vocab_size(
           token_counts, alphabet, target_vocab_size, threshold, min_count,
@@ -120,15 +142,18 @@ class Subtokenizer(object):
       tf.compat.v1.logging.info("Generated vocabulary with %d subtokens." %
                                 len(subtoken_list))
       _save_vocab_file(vocab_file, subtoken_list)
-    return Subtokenizer(vocab_file)
+    return Subtokenizer(vocab_file, master_char_set=master_char_set)
 
   def encode(self, raw_string, add_eos=False):
     """Encodes a string into a list of int subtoken ids."""
     ret = []
-    tokens = _split_string_to_tokens(native_to_unicode(raw_string))
+    tokens = _split_string_to_tokens(native_to_unicode(raw_string),
+                                     self._master_char_set)
     for token in tokens:
       ret.extend(self._token_to_subtoken_ids(token))
     if add_eos:
+      assert EOS in self.subtoken_list, \
+          "Can't append 'EOS' because it is not in list of known subtokens."
       ret.append(EOS_ID)
     return ret
 
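init_from_files() now forwards master_char_set both to _count_tokens() (so the vocabulary is built with the same splitting rule) and to the Subtokenizer it returns, and encode() splits on the instance's stored set. A hedged sketch of building a vocabulary with the new argument (file names and the size/threshold values are placeholders):

    sub = tokenizer.Subtokenizer.init_from_files(
        "vocab.subwords",            # vocab file to create or reuse
        ["train.en", "train.de"],    # hypothetical training files
        target_vocab_size=32768,
        threshold=327,
        master_char_set=None)        # None falls back to _ALPHANUMERIC_CHAR_SET

    ids = sub.encode("hello world", add_eos=True)  # asserts EOS is in the vocab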
@@ -161,13 +186,14 @@ class Subtokenizer(object):
           "Subtokens argument passed into decode() must be a list of integers.")
 
     return _unicode_to_native(
-        _join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens)))
+        _join_tokens_to_string(self._subtoken_ids_to_tokens(subtokens),
+                               self._master_char_set))
 
   def _subtoken_ids_to_tokens(self, subtokens):
     """Convert list of int subtoken ids to a list of string tokens."""
     escaped_tokens = "".join([
         self.subtoken_list[s] for s in subtokens
         if s < len(self.subtoken_list)
     ])
     escaped_tokens = escaped_tokens.split("_")
 
     # All tokens in the vocabulary list have been escaped (see _escape_token())
@@ -204,7 +230,7 @@ def _load_vocab_file(vocab_file, reserved_tokens=None):
 
 def native_to_unicode(s):
   """Convert string to unicode (required in Python 2)."""
-  try:               # Python 2
+  try:  # Python 2
     return s if isinstance(s, unicode) else s.decode("utf-8")
   except NameError:  # Python 3
     return s
@@ -212,22 +238,22 @@ def native_to_unicode(s):
 
 def _unicode_to_native(s):
   """Convert string from unicode to native format (required in Python 2)."""
-  try:               # Python 2
+  try:  # Python 2
     return s.encode("utf-8") if isinstance(s, unicode) else s
   except NameError:  # Python 3
     return s
 
 
-def _split_string_to_tokens(text):
+def _split_string_to_tokens(text, master_char_set):
   """Splits text to a list of string tokens."""
   if not text:
     return []
   ret = []
   token_start = 0
   # Classify each character in the input string
-  is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
+  is_master = [c in master_char_set for c in text]
   for pos in xrange(1, len(text)):
-    if is_alnum[pos] != is_alnum[pos - 1]:
+    if is_master[pos] != is_master[pos - 1]:
       token = text[token_start:pos]
       if token != u" " or token_start == 0:
         ret.append(token)
@@ -237,12 +263,12 @@ def _split_string_to_tokens(text):
   return ret
 
 
-def _join_tokens_to_string(tokens):
+def _join_tokens_to_string(tokens, master_char_set):
   """Join a list of string tokens into a single string."""
-  token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
+  token_is_master = [t[0] in master_char_set for t in tokens]
   ret = []
   for i, token in enumerate(tokens):
-    if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
+    if i > 0 and token_is_master[i - 1] and token_is_master[i]:
       ret.append(u" ")
     ret.append(token)
   return "".join(ret)
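_split_string_to_tokens() and _join_tokens_to_string() now receive the character set explicitly instead of reading the module-level constant. With the default set, the round trip matches the expectations already encoded in tokenizer_test.py:

    toks = tokenizer._split_string_to_tokens(
        u"test? testing 123.", tokenizer._ALPHANUMERIC_CHAR_SET)
    # ["test", "? ", "testing", "123", "."]

    s = tokenizer._join_tokens_to_string(toks, tokenizer._ALPHANUMERIC_CHAR_SET)
    # "test? testing 123."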
@@ -324,7 +350,10 @@ def _unescape_token(token):
   return _UNESCAPE_REGEX.sub(match, token)
 
 
-def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
+def _count_tokens(files,
+                  file_byte_limit=1e6,
+                  correct_strip=True,
+                  master_char_set=None):
   """Return token counts of words in the files.
 
   Samples file_byte_limit bytes from each file, and counts the words that appear
@@ -337,11 +366,15 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
       vocabulary generation for PY2. Sets correct_strip to False in PY2 to
       reproduce previous common public result. Sets correct_strip to True will
       let PY2 and PY3 get a consistent vocabulary.
+    master_char_set: the char set.
 
   Returns:
     Dictionary mapping tokens to the number of times they appear in the sampled
     lines from the files.
   """
+  if master_char_set is None:
+    master_char_set = _ALPHANUMERIC_CHAR_SET
+
   token_counts = collections.defaultdict(int)
 
   for filepath in files:
@@ -362,7 +395,8 @@ def _count_tokens(files, file_byte_limit=1e6, correct_strip=True):
           counter = 0
 
           # Add words to token counts
-          for token in _split_string_to_tokens(native_to_unicode(line)):
+          for token in _split_string_to_tokens(native_to_unicode(line),
+                                               master_char_set):
             token_counts[token] += 1
   return token_counts
 
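_count_tokens() gains the same optional argument and passes it through to the splitter when counting words. A small sketch (the input file name is hypothetical):

    counts = tokenizer._count_tokens(
        ["train.en"],
        file_byte_limit=1e6,
        correct_strip=True,
        master_char_set=None)  # defaults to _ALPHANUMERIC_CHAR_SET
    # counts is a collections.defaultdict(int) mapping token -> occurrences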
@@ -394,9 +428,12 @@ def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
   return ret
 
 
-def _generate_subtokens_with_target_vocab_size(
-    token_counts, alphabet, target_size, threshold, min_count=None,
-    reserved_tokens=None):
+def _generate_subtokens_with_target_vocab_size(token_counts,
+                                               alphabet,
+                                               target_size,
+                                               threshold,
+                                               min_count=None,
+                                               reserved_tokens=None):
   """Generate subtoken vocabulary close to the target size."""
   if reserved_tokens is None:
     reserved_tokens = RESERVED_TOKENS
@@ -449,8 +486,8 @@ def _generate_alphabet_dict(iterable, reserved_tokens=None):
   return alphabet
 
 
-def _count_and_gen_subtokens(
-    token_counts, alphabet, subtoken_dict, max_subtoken_length):
+def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
+                             max_subtoken_length):
   """Count number of times subtokens appear, and generate new subtokens.
 
   Args:
@@ -468,8 +505,8 @@ def _count_and_gen_subtokens(
   subtoken_counts = collections.defaultdict(int)
   for token, count in six.iteritems(token_counts):
     token = _escape_token(token, alphabet)
-    subtokens = _split_token_to_subtokens(
-        token, subtoken_dict, max_subtoken_length)
+    subtokens = _split_token_to_subtokens(token, subtoken_dict,
+                                          max_subtoken_length)
 
     # Generate new subtokens by taking substrings from token.
     start = 0
@@ -503,8 +540,10 @@ def _filter_and_bucket_subtokens(subtoken_counts, min_count):
   return subtoken_buckets
 
 
-def _gen_new_subtoken_list(
-    subtoken_counts, min_count, alphabet, reserved_tokens=None):
+def _gen_new_subtoken_list(subtoken_counts,
+                           min_count,
+                           alphabet,
+                           reserved_tokens=None):
   """Generate candidate subtokens ordered by count, and new max subtoken length.
 
   Add subtokens to the candiate list in order of length (longest subtokens
@@ -575,9 +614,11 @@ def _gen_new_subtoken_list(
   return subtoken_list, max_subtoken_length
 
 
-def _generate_subtokens(
-    token_counts, alphabet, min_count, num_iterations=4,
-    reserved_tokens=None):
+def _generate_subtokens(token_counts,
+                        alphabet,
+                        min_count,
+                        num_iterations=4,
+                        reserved_tokens=None):
   """Create a list of subtokens in decreasing order of frequency.
 
   Args:
@@ -609,8 +650,9 @@ def _generate_subtokens(
     # Create dict mapping subtoken->count, with additional subtokens created
     # from substrings taken from the tokens.
-    subtoken_counts = _count_and_gen_subtokens(
-        token_counts, alphabet, subtoken_dict, max_subtoken_length)
+    subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
+                                               subtoken_dict,
+                                               max_subtoken_length)
 
     # Generate new list of subtokens sorted by subtoken count.
     subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
official/nlp/transformer/utils/tokenizer_test.py (view file @ 2416dd9c)
@@ -59,13 +59,15 @@ class StringHelperTest(tf.test.TestCase):
   def test_split_string_to_tokens(self):
     text = "test? testing 123."
-    tokens = tokenizer._split_string_to_tokens(text)
+    tokens = tokenizer._split_string_to_tokens(text,
+                                               tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
 
   def test_join_tokens_to_string(self):
     tokens = ["test", "? ", "testing", "123", "."]
-    s = tokenizer._join_tokens_to_string(tokens)
+    s = tokenizer._join_tokens_to_string(tokens,
+                                         tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual("test? testing 123.", s)
 
   def test_escape_token(self):
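The existing tests simply pass the default set through, so their expected outputs do not change. A hypothetical additional case (not part of this commit) showing how a non-default master set alters splitting:

    # Sketch only: with a letters-only master set, digits become a separate token.
    def test_split_string_to_tokens_custom_set(self):
      letters_only = set("abcdefghijklmnopqrstuvwxyz")
      tokens = tokenizer._split_string_to_tokens(u"abc123", letters_only)
      self.assertEqual(["abc", "123"], tokens)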
@@ -79,8 +81,7 @@ class StringHelperTest(tf.test.TestCase):
     escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
     unescaped_token = tokenizer._unescape_token(escaped_token)
 
-    self.assertEqual(
-        "Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
+    self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
 
   def test_list_to_index_dict(self):
     lst = ["test", "strings"]
@@ -93,8 +94,8 @@ class StringHelperTest(tf.test.TestCase):
     subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
     max_subtoken_length = 2
-    subtokens = tokenizer._split_token_to_subtokens(
-        token, subtoken_dict, max_subtoken_length)
+    subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
+                                                    max_subtoken_length)
     self.assertEqual(["ab", "c"], subtokens)
 
   def test_generate_alphabet_dict(self):
@@ -124,12 +125,28 @@ class StringHelperTest(tf.test.TestCase):
     self.assertIsInstance(subtoken_counts, collections.defaultdict)
     self.assertDictEqual(
-        {"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
-         "abc": 5, "bc_": 5, "abc_": 5}, subtoken_counts)
+        {
+            "a": 5,
+            "b": 5,
+            "c": 5,
+            "_": 5,
+            "ab": 5,
+            "bc": 5,
+            "c_": 5,
+            "abc": 5,
+            "bc_": 5,
+            "abc_": 5
+        }, subtoken_counts)
 
   def test_filter_and_bucket_subtokens(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5})
+    subtoken_counts = collections.defaultdict(
+        int, {
+            "a": 2,
+            "b": 4,
+            "c": 1,
+            "ab": 6,
+            "ac": 3,
+            "abbc": 5
+        })
     min_count = 3
     subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
@@ -142,8 +159,12 @@ class StringHelperTest(tf.test.TestCase):
     self.assertEqual(set(["abbc"]), subtoken_buckets[4])
 
   def test_gen_new_subtoken_list(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"translate": 10, "t": 40, "tr": 16, "tra": 12})
+    subtoken_counts = collections.defaultdict(
+        int, {
+            "translate": 10,
+            "t": 40,
+            "tr": 16,
+            "tra": 12
+        })
     min_count = 5
     alphabet = set("translate")
     reserved_tokens = ["reserved", "tokens"]
@@ -167,8 +188,9 @@ class StringHelperTest(tf.test.TestCase):
     num_iterations = 1
     reserved_tokens = ["reserved", "tokens"]
-    vocab_list = tokenizer._generate_subtokens(
-        token_counts, alphabet, min_count, num_iterations, reserved_tokens)
+    vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
+                                               min_count, num_iterations,
+                                               reserved_tokens)
 
     # Check that reserved tokens are at the front of the list
     self.assertEqual(vocab_list[:2], reserved_tokens)