ModelZoo / ResNet50_tensorflow / Commits

Commit eaf8c8c3
Authored Dec 14, 2020 by Frederick Liu; committed by A. Unique TensorFlower, Dec 14, 2020

Internal change

PiperOrigin-RevId: 347389234
Parent: a9edf472
Showing 1 changed file with 133 additions and 0 deletions.

official/nlp/data/train_sentencepiece.py  (new file, mode 100644, +133 −0)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to train sentencepiece model from tensorflow datasets.
Reserved tokens:
pad: 0,
eos: 1,
unk: 2
(bos is not reserved)
"""
import os
import tempfile
from typing import List, Tuple

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds

from sentencepiece import SentencePieceTrainer
FLAGS = flags.FLAGS

flags.DEFINE_string("output_model_path", None,
                    "Path to save the sentencepiece model.")
flags.mark_flag_as_required("output_model_path")
flags.DEFINE_string("tfds_dir", None, "Directory of the tfds.")
flags.DEFINE_string("tfds_name", "wmt14_translate/de-en",
                    "Name of the dataset we generate the vocabulary from.")
flags.DEFINE_string("tfds_split", "train", "Split of the dataset.")
flags.DEFINE_integer("vocab_size", 32000, "Size of vocabulary.")
flags.DEFINE_integer(
    "max_char", -1,
    "Maximum number of characters to use. "
    "If a non-positive number is provided, all sentences are used.")
flags.DEFINE_string("model_type", "bpe",
                    "Model algorithm: unigram, bpe, word or char.")
flags.DEFINE_float("character_coverage", 0.9995,
                   "Character coverage to determine the minimum symbols.")
flags.DEFINE_list(
    "data_keys", ["en", "de"],
    "Comma-separated list of keys to use for training the vocabulary.")
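Taken together, these flags imply an invocation along the lines of python3 train_sentencepiece.py --output_model_path=... --tfds_dir=... from the command line. The snippet below is a minimal, hypothetical smoke-test driver for the same thing in Python; it assumes the script is importable as train_sentencepiece, every path and value is a placeholder, and running it will really read the configured TFDS split.

# Hypothetical driver, not part of the commit; all paths/values are placeholders.
from absl import app
import train_sentencepiece as ts  # assumed module name for this script

app.run(ts.main, argv=[
    "train_sentencepiece",
    "--output_model_path=/tmp/spm_de_en",
    "--tfds_dir=/data/tensorflow_datasets",
    "--vocab_size=32000",
    "--model_type=bpe",
])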
def dump_chars_to_textfile(dataset: tf.data.Dataset,
                           data_keys: Tuple[str],
                           max_char: int = -1):
  """Writes part of a TFDS sentence dataset to lines in a text file.

  Args:
    dataset: tf.data.Dataset containing string-valued features.
    data_keys: which keys of each dataset example to dump.
    max_char: maximum number of characters to dump to the text file.

  Returns:
    The name of the temp file holding the dumped dataset bytes.
  """
  ds_iter = dataset.as_numpy_iterator()
  with tempfile.NamedTemporaryFile(delete=False) as outfp:
    char_count = 0
    while True:
      example = next(ds_iter, None)
      if example is None or (max_char > 0 and char_count > max_char):
        break
      for k in data_keys:
        line = example[k] + b"\n"
        char_count += len(line)
        outfp.write(line)
  return outfp.name
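For illustration, a minimal sketch of this helper on a tiny in-memory dataset. The keys and sentences are made up for the example; it relies only on the imports and the function defined above.

# Tiny, hypothetical dataset: one parallel en/de pair.
ds = tf.data.Dataset.from_tensor_slices(
    {"en": [b"hello world"], "de": [b"hallo welt"]})
tmp_path = dump_chars_to_textfile(ds, ("en", "de"))
with open(tmp_path, "rb") as f:
  print(f.read())  # b"hello world\nhallo welt\n"
os.remove(tmp_path)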
def train_sentencepiece(file_path: str,
                        model_path: str,
                        vocab_size: int,
                        character_coverage: float,
                        model_type: str):
  """Trains a SentencePiece tokenizer from a subset of a tf dataset.

  Args:
    file_path: path of the text data to train sentencepiece on.
    model_path: path prefix to save the trained vocab model to.
    vocab_size: size of the vocabulary to train.
    character_coverage: amount of characters covered by the model. Good
      defaults are 0.9995 for languages with a rich character set like
      Japanese or Chinese, and 1.0 for other languages with a small
      character set.
    model_type: type of sentencepiece vocab to train.
  """
  argstr = " ".join([
      f"--input={file_path}",
      f"--vocab_size={vocab_size}",
      f"--character_coverage={character_coverage}",
      f"--model_prefix={model_path}",
      f"--model_type={model_type}",
      "--bos_id=-1",
      "--pad_id=0",
      "--eos_id=1",
      "--unk_id=2",
  ])
  SentencePieceTrainer.Train(argstr)
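The same training call can also be written with keyword arguments instead of a flag string. This sketch assumes a reasonably recent sentencepiece release that accepts the kwargs form; the paths are placeholders.

# Keyword-argument equivalent of the flag string above (illustrative only).
from sentencepiece import SentencePieceTrainer

SentencePieceTrainer.Train(
    input="/tmp/sentences.txt",      # placeholder: one sentence per line
    model_prefix="/tmp/spm_de_en",   # writes /tmp/spm_de_en.model and .vocab
    vocab_size=32000,
    character_coverage=0.9995,
    model_type="bpe",
    bos_id=-1,                       # bos is not reserved (see module docstring)
    pad_id=0,
    eos_id=1,
    unk_id=2)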
def main(argv: List[str]):
  del argv
  builder = tfds.builder(FLAGS.tfds_name, data_dir=FLAGS.tfds_dir)
  ds = builder.as_dataset(split=FLAGS.tfds_split)
  tmp_filename = dump_chars_to_textfile(ds, FLAGS.data_keys, FLAGS.max_char)
  logging.info("Sentencepiece model will be placed here: %s",
               FLAGS.output_model_path)
  train_sentencepiece(tmp_filename, FLAGS.output_model_path, FLAGS.vocab_size,
                      FLAGS.character_coverage, FLAGS.model_type)
  os.remove(tmp_filename)


if __name__ == "__main__":
  app.run(main)
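Note that output_model_path is handed to the trainer as --model_prefix, so the trained files end up at <output_model_path>.model and <output_model_path>.vocab rather than at the path itself. A quick, hypothetical check that the reserved IDs of the trained model match the module docstring (assumes a recent sentencepiece release; the path is a placeholder):

# Load the trained model and verify the reserved token IDs (illustrative only).
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="/tmp/spm_de_en.model")
assert sp.pad_id() == 0
assert sp.eos_id() == 1
assert sp.unk_id() == 2
assert sp.bos_id() == -1   # bos was disabled via --bos_id=-1
print(sp.encode("hello world", out_type=str))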