ModelZoo / ResNet50_tensorflow · Commits · 999fae62

Commit 999fae62, authored Aug 12, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Aug 12, 2020.

Internal change

PiperOrigin-RevId: 326286926
Parent: 94561082

Changes: 205. Showing 20 changed files with 251 additions and 179 deletions (+251, -179).
official/nlp/nhnet/optimizer.py              (+2, -4)
official/nlp/nhnet/raw_data_process.py       (+1, -0)
official/nlp/nhnet/raw_data_processor.py     (+6, -4)
official/nlp/nhnet/trainer.py                (+1, -0)
official/nlp/nhnet/utils.py                  (+2, -0)
official/nlp/tasks/electra_task.py           (+1, -0)
official/nlp/tasks/masked_lm.py              (+1, -0)
official/nlp/tasks/masked_lm_test.py         (+1, -2)
official/nlp/tasks/question_answering.py     (+30, -27)
official/nlp/tasks/question_answering_test.py (+1, -0)
official/nlp/tasks/sentence_prediction.py    (+2, -2)
official/nlp/tasks/tagging.py                (+1, -2)
official/nlp/tasks/tagging_test.py           (+1, -0)
official/nlp/tasks/utils.py                  (+2, -3)
official/nlp/transformer/attention_layer.py  (+17, -8)
official/nlp/transformer/beam_search_v1.py   (+20, -17)
official/nlp/transformer/data_download.py    (+58, -50)
official/nlp/transformer/data_pipeline.py    (+55, -37)
official/nlp/transformer/embedding_layer.py  (+3, -1)
official/nlp/transformer/misc.py             (+46, -22)
official/nlp/nhnet/optimizer.py

@@ -71,10 +71,8 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
 def create_optimizer(params: params_dict.ParamsDict):
   """Creates optimizer."""
-  lr_schedule = LearningRateSchedule(
-      params.learning_rate,
-      params.hidden_size,
-      params.learning_rate_warmup_steps)
+  lr_schedule = LearningRateSchedule(params.learning_rate, params.hidden_size,
+                                     params.learning_rate_warmup_steps)
   return tf.keras.optimizers.Adam(
       learning_rate=lr_schedule,
       beta_1=params.adam_beta1,
 ...
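
Not part of the commit: create_optimizer above plugs a custom Keras LearningRateSchedule into Adam. A minimal runnable sketch of that pattern follows; the schedule class and the parameter values are illustrative stand-ins, not the NHNet ones.

import tensorflow as tf

class WarmupRsqrtSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Toy schedule: linear warmup, then reciprocal square-root decay."""

  def __init__(self, base_lr, hidden_size, warmup_steps):
    super().__init__()
    self.base_lr = base_lr
    self.hidden_size = hidden_size
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    warmup = tf.cast(self.warmup_steps, tf.float32)
    # Scale down during warmup, decay as 1/sqrt(step) afterwards.
    return (self.base_lr * self.hidden_size**-0.5 *
            tf.minimum(step * warmup**-1.5,
                       tf.math.rsqrt(tf.maximum(step, 1.0))))

optimizer = tf.keras.optimizers.Adam(
    learning_rate=WarmupRsqrtSchedule(2.0, 512, 4000), beta_1=0.9)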
official/nlp/nhnet/raw_data_process.py

@@ -16,6 +16,7 @@
 """Processes crawled content from news URLs by generating tfrecords."""
 import os
 from absl import app
 from absl import flags
 from official.nlp.nhnet import raw_data_processor
 ...
official/nlp/nhnet/raw_data_processor.py

@@ -20,6 +20,7 @@ import json
 import multiprocessing
 import os
 import urllib.parse
 import tensorflow as tf
 from official.nlp.bert import tokenization
 ...
@@ -47,10 +48,10 @@ class RawDataProcessor(object):
       max_num_articles: Maximum number of articles in a story.
       include_article_title_in_passage: Whether to include article title in
         article passage.
       include_text_snippet_in_example: Whether to include text snippet (headline
         and article content) in generated tensorflow Examples, for debug usage.
         If include_article_title_in_passage=True, title and body will be
         separated by [SEP].
     """
     self.articles = dict()
     self.tokenizer = tokenization.FullTokenizer(
 ...
@@ -156,6 +157,7 @@ class RawDataProcessor(object):
   def _get_single_story_features(self, story_headline, articles):
     """Converts a list of articles to a tensorflow Example."""

     def get_text_snippet(article):
       if article.text_b:
         return " [SEP] ".join([article.text_a, article.text_b])
 ...
official/nlp/nhnet/trainer.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
 ...
official/nlp/nhnet/utils.py

@@ -44,6 +44,8 @@ def encoder_common_layers(transformer_block):
       transformer_block._intermediate_dense,
       transformer_block._output_dense,
       transformer_block._output_layer_norm
   ]
   # pylint: enable=protected-access
 ...
official/nlp/tasks/electra_task.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
 import dataclasses
 import tensorflow as tf
 ...
official/nlp/tasks/masked_lm.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Masked language task."""
 import dataclasses
 import tensorflow as tf
 ...
official/nlp/tasks/masked_lm_test.py

@@ -52,8 +52,7 @@ class MLMTaskTest(tf.test.TestCase):
     task.validation_step(next(iterator), model, metrics=metrics)
     # Saves a checkpoint.
-    ckpt = tf.train.Checkpoint(
-        model=model, **model.checkpoint_items)
+    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
     task.initialize(model)
 ...
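
The test above bundles the Keras model and extra trackables into one checkpoint via **model.checkpoint_items. A minimal sketch of that tf.train.Checkpoint pattern; the extra_items dict and paths here are illustrative, not the model-zoo API.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))

# Any extra trackables can be passed as keyword arguments alongside the model.
extra_items = {"dense_head": tf.keras.layers.Dense(2)}
ckpt = tf.train.Checkpoint(model=model, **extra_items)
save_path = ckpt.save("/tmp/mlm_test_ckpt/ckpt")

# Restoring builds a matching Checkpoint and reads the saved values back.
restored = tf.train.Checkpoint(model=model, **extra_items)
restored.restore(save_path)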
official/nlp/tasks/question_answering.py

@@ -111,9 +111,7 @@ class QuestionAnsweringTask(base_task.Task):
         tf.cast(start_logits, dtype=tf.float32),
         from_logits=True)
     end_loss = tf.keras.losses.sparse_categorical_crossentropy(
         end_positions,
         tf.cast(end_logits, dtype=tf.float32), from_logits=True)
     loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
     return loss
 ...
@@ -142,8 +140,7 @@ class QuestionAnsweringTask(base_task.Task):
     kwargs = dict(
         examples=eval_examples,
         tokenizer=tokenization.FullTokenizer(
             vocab_file=params.vocab_file,
             do_lower_case=params.do_lower_case),
         max_seq_length=params.seq_length,
         doc_stride=params.doc_stride,
         max_query_length=params.query_length,
 ...
@@ -192,8 +189,8 @@ class QuestionAnsweringTask(base_task.Task):
       input_path = self._tf_record_input_path
     dataloader_params = params.replace(input_path=input_path)
     return data_loader_factory.get_data_loader(
         dataloader_params).load(
             input_context)

   def build_metrics(self, training=None):
     del training
 ...
@@ -209,16 +206,19 @@ class QuestionAnsweringTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     metrics = dict([(metric.name, metric) for metric in metrics])
     start_logits, end_logits = model_outputs
     metrics['start_position_accuracy'].update_state(
         labels['start_positions'],
         start_logits)
     metrics['end_position_accuracy'].update_state(
         labels['end_positions'],
         end_logits)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     start_logits, end_logits = model_outputs
     compiled_metrics.update_state(
         y_true=labels,  # labels has keys 'start_positions' and 'end_positions'.
         y_pred={'start_positions': start_logits, 'end_positions': end_logits})

   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     features, _ = inputs
 ...
@@ -242,16 +242,16 @@ class QuestionAnsweringTask(base_task.Task):
     state = []
     for unique_ids, start_logits, end_logits in zip(
         step_outputs['unique_ids'],
         step_outputs['start_logits'],
         step_outputs['end_logits']):
       u_ids, s_logits, e_logits = (
           unique_ids.numpy(), start_logits.numpy(),
           end_logits.numpy())
       for values in zip(u_ids, s_logits, e_logits):
-        state.append(self.raw_aggregated_result(
-            unique_id=values[0],
-            start_logits=values[1].tolist(),
-            end_logits=values[2].tolist()))
+        state.append(
+            self.raw_aggregated_result(
+                unique_id=values[0],
+                start_logits=values[1].tolist(),
+                end_logits=values[2].tolist()))
     return state

   def reduce_aggregated_logs(self, aggregated_logs):
 ...
@@ -269,13 +269,13 @@ class QuestionAnsweringTask(base_task.Task):
             self.task_config.null_score_diff_threshold),
         verbose=False))
     with tf.io.gfile.GFile(
         self.task_config.validation_data.input_path,
         'r') as reader:
       dataset_json = json.load(reader)
       pred_dataset = dataset_json['data']
     if self.task_config.validation_data.version_2_with_negative:
       eval_metrics = squad_evaluate_v2_0.evaluate(
           pred_dataset, all_predictions,
           scores_diff)
       # Filter out useless metrics, such as start_position_accuracy that
       # we did not actually compute.
       eval_metrics = {
 ...
@@ -284,13 +284,16 @@ class QuestionAnsweringTask(base_task.Task):
           'final_f1': eval_metrics['final_f1'] / 100.0,  # scale back to [0, 1].
           'f1_threshold': eval_metrics['final_f1_thresh'],
           'has_answer_exact_match': eval_metrics['HasAns_exact'],
-          'has_answer_f1': eval_metrics['HasAns_f1']}
+          'has_answer_f1': eval_metrics['HasAns_f1']
+      }
     else:
       eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
       # Filter out useless metrics, such as start_position_accuracy that
       # we did not actually compute.
-      eval_metrics = {'exact_match': eval_metrics['exact_match'],
-                      'final_f1': eval_metrics['final_f1']}
+      eval_metrics = {
+          'exact_match': eval_metrics['exact_match'],
+          'final_f1': eval_metrics['final_f1']
+      }
     return eval_metrics
 ...
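
Not from the diff itself: a small standalone sketch of the span loss computed above, sparse categorical cross-entropy over start and end logits, averaged. Shapes and data are toy values.

import tensorflow as tf

batch, seq_len = 2, 8
start_logits = tf.random.normal([batch, seq_len])
end_logits = tf.random.normal([batch, seq_len])
start_positions = tf.constant([1, 3])
end_positions = tf.constant([2, 5])

start_loss = tf.keras.losses.sparse_categorical_crossentropy(
    start_positions, tf.cast(start_logits, tf.float32), from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
    end_positions, tf.cast(end_logits, tf.float32), from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
print(float(loss))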
official/nlp/tasks/question_answering_test.py

@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
 from absl.testing import parameterized
 import tensorflow as tf
 ...
official/nlp/tasks/sentence_prediction.py

@@ -35,7 +35,6 @@ from official.nlp.data import data_loader_factory
 from official.nlp.modeling import models
 from official.nlp.tasks import utils

 METRIC_TYPES = frozenset(
     ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
 ...
@@ -137,7 +136,8 @@ class SentencePredictionTask(base_task.Task):
       metrics = [tf.keras.metrics.MeanSquaredError()]
     else:
       metrics = [
-          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
+          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
+      ]
     return metrics

   def process_metrics(self, metrics, labels, model_outputs):
 ...
official/nlp/tasks/tagging.py

@@ -250,8 +250,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
   cur_predict_ids = state['predict_ids']
   cur_sentence_ids = state['sentence_ids']
   for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
       outputs['predict_ids'], outputs['label_mask'],
       outputs['sentence_ids']):
     for tmp_predict_ids, tmp_label_mask, tmp_sentence_id in zip(
         batch_predict_ids.numpy(), batch_label_mask.numpy(),
         batch_sentence_ids.numpy()):
 ...
official/nlp/tasks/tagging_test.py

@@ -16,6 +16,7 @@
 """Tests for official.nlp.tasks.tagging."""
 import functools
 import os
 import numpy as np
 import tensorflow as tf
 ...
official/nlp/tasks/utils.py

@@ -38,15 +38,14 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
 def predict(predict_step_fn: Callable[[Any], Any],
             aggregate_fn: Callable[[Any, Any], Any],
             dataset: tf.data.Dataset):
   """Runs prediction.

   Args:
     predict_step_fn: A callable such as `def predict_step(inputs)`, where
       `inputs` are input tensors.
     aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
       `value` is the outputs from `predict_step_fn`.
     dataset: A `tf.data.Dataset` object.

   Returns:
 ...
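
For context on the predict helper documented above: it folds a per-batch step function's outputs over a dataset with an aggregate function. A minimal sketch of that contract using plain tf.data; the step and aggregate functions here are illustrative.

import tensorflow as tf

def predict(predict_step_fn, aggregate_fn, dataset):
  """Folds predict_step_fn outputs over the dataset with aggregate_fn."""
  state = None
  for batch in dataset:
    state = aggregate_fn(state, predict_step_fn(batch))
  return state

ds = tf.data.Dataset.range(10).batch(4)
# Step: square the batch; aggregate: collect results into a list.
result = predict(lambda x: x * x,
                 lambda state, value: (state or []) + [value.numpy()],
                 ds)
print(result)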
official/nlp/transformer/attention_layer.py

@@ -88,7 +88,12 @@ class Attention(tf.keras.layers.Layer):
         "attention_dropout": self.attention_dropout,
     }

-  def call(self, query_input, source_input, bias, training, cache=None,
-           decode_loop_step=None):
+  def call(self,
+           query_input,
+           source_input,
+           bias,
+           training,
+           cache=None,
+           decode_loop_step=None):
     """Apply attention mechanism to query_input and source_input.
 ...
@@ -102,9 +107,9 @@ class Attention(tf.keras.layers.Layer):
       cache: (Used during prediction) A dictionary with tensors containing
         results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
            "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
        i is the current decoded length for non-padded decode, or max
        sequence length for padded decode.
       decode_loop_step: An integer, step number of the decoding loop. Used only
         for autoregressive inference on TPU.
 ...
@@ -142,7 +147,7 @@ class Attention(tf.keras.layers.Layer):
     # Scale query to prevent the dot product between query and key from growing
     # too large.
     depth = (self.hidden_size // self.num_heads)
-    query *= depth ** -0.5
+    query *= depth**-0.5

     # Calculate dot product attention
     logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
 ...
@@ -164,7 +169,11 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""

-  def call(self, query_input, bias, training, cache=None,
-           decode_loop_step=None):
+  def call(self,
+           query_input,
+           bias,
+           training,
+           cache=None,
+           decode_loop_step=None):
     return super(SelfAttention, self).call(
         query_input, query_input, bias,
         training, cache, decode_loop_step)
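
Not part of the commit: a tiny sketch of the scaled dot-product step visible above, with the query scaled by depth**-0.5 and attention logits computed through the same einsum pattern. Shapes are toy values.

import tensorflow as tf

batch, length, num_heads, dim_per_head = 2, 5, 4, 16
query = tf.random.normal([batch, length, num_heads, dim_per_head])
key = tf.random.normal([batch, length, num_heads, dim_per_head])
value = tf.random.normal([batch, length, num_heads, dim_per_head])

# Scale query to keep the dot products from growing with dim_per_head.
query *= dim_per_head**-0.5
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)   # [batch, heads, from, to]
weights = tf.nn.softmax(logits, axis=-1)
context = tf.einsum("BNFT,BTNH->BFNH", weights, value)
print(context.shape)  # (2, 5, 4, 16)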
official/nlp/transformer/beam_search_v1.py

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Beam search to find the translated sequence with the highest probability.
-"""
+"""Beam search to find the translated sequence with the highest probability."""
 import tensorflow.compat.v1 as tf

 from official.nlp.modeling.ops import beam_search
 ...
@@ -41,23 +40,27 @@ class SequenceBeamSearch(beam_search.SequenceBeamSearch):
     return finished_seq, finished_scores


 def sequence_beam_search(
     symbols_to_logits_fn,
-    initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id, padded_decode=False):
+    initial_ids,
+    initial_cache,
+    vocab_size,
+    beam_size,
+    alpha,
+    max_decode_length,
+    eos_id,
+    padded_decode=False):
   """Search for sequence of subtoken ids with the largest probability.

   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
-      arguments. The passed in arguments will have shape:
-        ids -> A tensor with shape [batch_size * beam_size, index].
-        index -> A scalar.
-        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
-      The function must return a tuple of logits and new cache:
-        logits -> A tensor with shape [batch * beam_size, vocab_size].
-        new cache -> A nested dictionary with the same shape/structure as the
-          inputted cache.
-    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
-      each batch item.
+      arguments. The passed in arguments will have shape: ids -> A tensor with
+      shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
+      nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache: logits -> A
+      tensor with shape [batch * beam_size, vocab_size]. new cache -> A nested
+      dictionary with the same shape/structure as the inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for each
+      batch item.
     initial_cache: A dictionary, containing starting decoder variables
       information.
     vocab_size: An integer, the size of the vocabulary, used for topk
 ...
@@ -67,8 +70,8 @@ def sequence_beam_search(
     max_decode_length: An integer, the maximum length to decoded a sequence.
     eos_id: An integer, ID of eos token, used to determine when a sequence has
       finished.
     padded_decode: A bool, indicating if max_sequence_length padding is used
       for beam search.

   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
 ...
official/nlp/transformer/data_download.py

@@ -23,6 +23,7 @@ import random
 import tarfile

 # pylint: disable=g-bad-import-order
 from absl import app
 from absl import flags
 from absl import logging
 ...
@@ -64,22 +65,18 @@ _TRAIN_DATA_SOURCES = [
 # Use pre-defined minimum count to generate subtoken vocabulary.
 _TRAIN_DATA_MIN_COUNT = 6

-_EVAL_DATA_SOURCES = [
-    {
-        "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-        "input": "newstest2013.en",
-        "target": "newstest2013.de",
-    }
-]
+_EVAL_DATA_SOURCES = [{
+    "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+    "input": "newstest2013.en",
+    "target": "newstest2013.de",
+}]

-_TEST_DATA_SOURCES = [
-    {
-        "url": ("https://storage.googleapis.com/tf-perf-public/"
-                "official_transformer/test_data/newstest2014.tgz"),
-        "input": "newstest2014.en",
-        "target": "newstest2014.de",
-    }
-]
+_TEST_DATA_SOURCES = [{
+    "url": ("https://storage.googleapis.com/tf-perf-public/"
+            "official_transformer/test_data/newstest2014.tgz"),
+    "input": "newstest2014.en",
+    "target": "newstest2014.de",
+}]

 # Vocabulary constants
 _TARGET_VOCAB_SIZE = 32768  # Number of subtokens in the vocabulary list.
 ...
@@ -114,7 +111,9 @@ def find_file(path, filename, max_depth=5):
 # Download and extraction functions
 ###############################################################################
 def get_raw_files(raw_dir, data_source):
-  """Return raw files from source. Downloads/extracts if needed.
+  """Return raw files from source.
+
+  Downloads/extracts if needed.

   Args:
     raw_dir: string directory to store raw files
 ...
@@ -134,8 +133,8 @@ def get_raw_files(raw_dir, data_source):
       "targets": [],
   }  # keys
   for d in data_source:
     input_file, target_file = download_and_extract(
         raw_dir, d["url"],
         d["input"], d["target"])
     raw_files["inputs"].append(input_file)
     raw_files["targets"].append(target_file)
   return raw_files
 ...
@@ -167,7 +166,7 @@ def download_from_url(path, url):
   found_file = find_file(path, filename, max_depth=0)
   if found_file is None:
     filename = os.path.join(path, filename)
-    logging.info("Downloading from %s to %s." % (url, filename))
+    logging.info("Downloading from %s to %s.", url, filename)
     inprogress_filepath = six.ensure_str(filename) + ".incomplete"
     inprogress_filepath, _ = urllib.request.urlretrieve(
         url, inprogress_filepath, reporthook=download_report_hook)
 ...
@@ -176,7 +175,7 @@ def download_from_url(path, url):
     tf.gfile.Rename(inprogress_filepath, filename)
     return filename
   else:
-    logging.info("Already downloaded: %s (at %s)." % (url, found_file))
+    logging.info("Already downloaded: %s (at %s).", url, found_file)
     return found_file
 ...
@@ -199,14 +198,14 @@ def download_and_extract(path, url, input_filename, target_filename):
   input_file = find_file(path, input_filename)
   target_file = find_file(path, target_filename)
   if input_file and target_file:
-    logging.info("Already downloaded and extracted %s." % url)
+    logging.info("Already downloaded and extracted %s.", url)
     return input_file, target_file

   # Download archive file if it doesn't already exist.
   compressed_file = download_from_url(path, url)

   # Extract compressed files
-  logging.info("Extracting %s." % compressed_file)
+  logging.info("Extracting %s.", compressed_file)
   with tarfile.open(compressed_file, "r:gz") as corpus_tar:
     corpus_tar.extractall(path)
 ...
@@ -236,13 +235,13 @@ def compile_files(raw_dir, raw_files, tag):
     raw_files: Dict containing filenames of input and target data.
       {"inputs": list of files containing data in input language
       "targets": list of files containing corresponding data in target language
      }
     tag: String to append to the compiled filename.

   Returns:
     Full path of compiled input and target files.
   """
-  logging.info("Compiling files with tag %s." % tag)
+  logging.info("Compiling files with tag %s.", tag)
   filename = "%s-%s" % (_PREFIX, tag)
   input_compiled_file = os.path.join(raw_dir,
                                      six.ensure_str(filename) + ".lang1")
 ...
@@ -255,7 +254,7 @@ def compile_files(raw_dir, raw_files, tag):
       input_file = raw_files["inputs"][i]
       target_file = raw_files["targets"][i]
-      logging.info("Reading files %s and %s." % (input_file, target_file))
+      logging.info("Reading files %s and %s.", input_file, target_file)
       write_file(input_writer, input_file)
       write_file(target_writer, target_file)
   return input_compiled_file, target_compiled_file
 ...
@@ -271,8 +270,7 @@ def write_file(writer, filename):
 ###############################################################################
 # Data preprocessing
 ###############################################################################
 def encode_and_save_files(
     subtokenizer, data_dir, raw_files, tag, total_shards):
   """Save data from files as encoded Examples in TFrecord format.

   Args:
 ...
@@ -287,14 +285,16 @@ def encode_and_save_files(
     List of all files produced.
   """
   # Create a file for each shard.
-  filepaths = [shard_filename(data_dir, tag, n + 1, total_shards)
-               for n in range(total_shards)]
+  filepaths = [
+      shard_filename(data_dir, tag, n + 1, total_shards)
+      for n in range(total_shards)
+  ]

   if all_exist(filepaths):
-    logging.info("Files with tag %s already exist." % tag)
+    logging.info("Files with tag %s already exist.", tag)
     return filepaths

-  logging.info("Saving files with tag %s." % tag)
+  logging.info("Saving files with tag %s.", tag)
   input_file = raw_files[0]
   target_file = raw_files[1]
 ...
@@ -302,13 +302,14 @@ def encode_and_save_files(
   tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
   writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
   counter, shard = 0, 0
-  for counter, (input_line, target_line) in enumerate(zip(
-      txt_line_iterator(input_file), txt_line_iterator(target_file))):
+  for counter, (input_line, target_line) in enumerate(
+      zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
     if counter > 0 and counter % 100000 == 0:
-      logging.info("\tSaving case %d." % counter)
+      logging.info("\tSaving case %d.", counter)
-    example = dict_to_example(
-        {"inputs": subtokenizer.encode(input_line, add_eos=True),
-         "targets": subtokenizer.encode(target_line, add_eos=True)})
+    example = dict_to_example({
+        "inputs": subtokenizer.encode(input_line, add_eos=True),
+        "targets": subtokenizer.encode(target_line, add_eos=True)
+    })
     writers[shard].write(example.SerializeToString())
     shard = (shard + 1) % total_shards
   for writer in writers:
 ...
@@ -329,7 +330,7 @@ def shard_filename(path, tag, shard_num, total_shards):
 def shuffle_records(fname):
   """Shuffle records in a single file."""
-  logging.info("Shuffling records in file %s" % fname)
+  logging.info("Shuffling records in file %s", fname)

   # Rename file prior to shuffling
   tmp_fname = six.ensure_str(fname) + ".unshuffled"
 ...
@@ -349,7 +350,7 @@ def shuffle_records(fname):
     for count, record in enumerate(records):
       w.write(record)
       if count > 0 and count % 100000 == 0:
-        logging.info("\tWriting record: %d" % count)
+        logging.info("\tWriting record: %d", count)
   tf.gfile.Remove(tmp_fname)
 ...
@@ -372,7 +373,7 @@ def all_exist(filepaths):
 def make_dir(path):
   if not tf.gfile.Exists(path):
-    logging.info("Creating directory %s" % path)
+    logging.info("Creating directory %s", path)
     tf.gfile.MakeDirs(path)
 ...
@@ -395,7 +396,10 @@ def main(unused_argv):
   train_files_flat = train_files["inputs"] + train_files["targets"]
   vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
   subtokenizer = tokenizer.Subtokenizer.init_from_files(
       vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
       min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)

   logging.info("Step 4/5: Compiling training and evaluation data")
 ...
@@ -404,12 +408,11 @@ def main(unused_argv):
   # Tokenize and save data as Examples in the TFRecord format.
   logging.info("Step 5/5: Preprocessing and saving data")
   train_tfrecord_files = encode_and_save_files(
       subtokenizer, FLAGS.data_dir,
       compiled_train_files, _TRAIN_TAG,
       _TRAIN_SHARDS)
   encode_and_save_files(
       subtokenizer, FLAGS.data_dir, compiled_eval_files,
       _EVAL_TAG,
       _EVAL_SHARDS)
   for fname in train_tfrecord_files:
     shuffle_records(fname)
 ...
@@ -418,15 +421,20 @@ def main(unused_argv):
 def define_data_download_flags():
   """Add flags specifying data download arguments."""
   flags.DEFINE_string(
       name="data_dir", short_name="dd", default="/tmp/translate_ende",
       help=flags_core.help_wrap(
           "Directory for where the translate_ende_wmt32k dataset is saved."))
   flags.DEFINE_string(
       name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
       help=flags_core.help_wrap(
           "Path where the raw data will be downloaded and extracted."))
   flags.DEFINE_bool(
       name="search", default=False,
       help=flags_core.help_wrap(
           "If set, use binary search to find the vocabulary set with size"
           "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
 ...
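
Most of the edits in this file replace eager %-formatting inside logging calls with absl's lazy, printf-style arguments. A small sketch of the difference, with illustrative values:

from absl import logging

url, filename = "http://example.com/dev.tgz", "/tmp/raw/dev.tgz"

# Before: the message is formatted even if the log line is filtered out.
logging.info("Downloading from %s to %s." % (url, filename))

# After: arguments are passed through and formatted only when emitted.
logging.info("Downloading from %s to %s.", url, filename)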
official/nlp/transformer/data_pipeline.py

@@ -87,8 +87,9 @@ def _parse_example(serialized_example):
 def _filter_max_length(example, max_length=256):
   """Indicates whether the example's length is lower than the maximum length."""
-  return tf.logical_and(tf.size(example[0]) <= max_length,
-                        tf.size(example[1]) <= max_length)
+  return tf.logical_and(
+      tf.size(example[0]) <= max_length,
+      tf.size(example[1]) <= max_length)


 def _get_example_length(example):
 ...
@@ -97,8 +98,9 @@ def _get_example_length(example):
   return length


 def _create_min_max_boundaries(max_length,
                                min_boundary=_MIN_BOUNDARY,
                                boundary_scale=_BOUNDARY_SCALE):
   """Create min and max boundary lists up to max_length.

   For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
 ...
@@ -165,8 +167,8 @@ def _batch_examples(dataset, batch_size, max_length):
     # TODO(xunkai): investigate if removing code branching improves performance.
     conditions_c = tf.logical_and(
         tf.less_equal(buckets_min, seq_length),
         tf.less(seq_length, buckets_max))
     bucket_id = tf.reduce_min(tf.where(conditions_c))
     return bucket_id
 ...
@@ -183,16 +185,23 @@ def _batch_examples(dataset, batch_size, max_length):
     # lengths as well. Resulting lengths of inputs and targets can differ.
     return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))

-  return dataset.apply(tf.data.experimental.group_by_window(
-      key_func=example_to_bucket_id,
-      reduce_func=batching_fn,
-      window_size=None,
-      window_size_func=window_size_fn))
+  return dataset.apply(
+      tf.data.experimental.group_by_window(
+          key_func=example_to_bucket_id,
+          reduce_func=batching_fn,
+          window_size=None,
+          window_size_func=window_size_fn))


-def _read_and_batch_from_files(
-    file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
-    static_batch=False, num_replicas=1, ctx=None):
+def _read_and_batch_from_files(file_pattern,
+                               batch_size,
+                               max_length,
+                               max_io_parallelism,
+                               shuffle,
+                               repeat,
+                               static_batch=False,
+                               num_replicas=1,
+                               ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".

   Args:
 ...
@@ -204,20 +213,18 @@ def _read_and_batch_from_files(
     repeat: Number of times to repeat the dataset. If None, the dataset is
       repeated forever.
     static_batch: Whether the batches in the dataset should have static shapes.
-      If True, the input is batched so that every batch has the
-      shape [batch_size // max_length, max_length]. If False, the input is
-      grouped by length, and batched so that batches may have different
-      shapes [N, M], where:
-        N * M <= batch_size
-        M <= max_length
-      In general, this setting should be False. Dynamic shapes allow the inputs
-      to be grouped so that the number of padding tokens is minimized, and helps
-      model training. In cases where the input shape must be static
-      (e.g. running on TPU), this setting should be set to True.
+      If True, the input is batched so that every batch has the shape
+      [batch_size // max_length, max_length]. If False, the input is grouped by
+      length, and batched so that batches may have different
+      shapes [N, M], where: N * M <= batch_size M <= max_length In general, this
+      setting should be False. Dynamic shapes allow the inputs to be grouped
+      so that the number of padding tokens is minimized, and helps model
+      training. In cases where the input shape must be static (e.g. running on
+      TPU), this setting should be set to True.
     num_replicas: Number of GPUs or other workers. We will generate global
       batches, and each global batch is equally divisible by number of replicas.
       Currently it is only effective when static_batch==True. TODO: make it
       effective when static_batch=False.
     ctx: Input context.

   Returns:
 ...
@@ -240,8 +247,8 @@ def _read_and_batch_from_files(
   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
-  dataset = dataset.map(_parse_example,
-                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  dataset = dataset.map(
+      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

   # Remove examples where the input or target length exceeds the maximum length,
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
 ...
@@ -252,7 +259,8 @@ def _read_and_batch_from_files(
     # into sentences, and finally expand to a global batch. It could prove
     # the global batch divisble for distribution strategy.
         int(batch_size // num_replicas // max_length * num_replicas),
         ([max_length], [max_length]), drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
     # TODO(xunkai): _batch_examples might need to do something special for
 ...
@@ -291,10 +299,15 @@ def train_input_fn(params, ctx=None):
   if params["use_synthetic_data"]:
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
-      file_pattern, params["batch_size"], params["max_length"],
-      params["max_io_parallelism"], shuffle=True,
-      repeat=params["repeat_dataset"], static_batch=params["static_batch"],
-      num_replicas=params["num_gpus"], ctx=ctx)
+      file_pattern,
+      params["batch_size"],
+      params["max_length"],
+      params["max_io_parallelism"],
+      shuffle=True,
+      repeat=params["repeat_dataset"],
+      static_batch=params["static_batch"],
+      num_replicas=params["num_gpus"],
+      ctx=ctx)


 def eval_input_fn(params, ctx=None):
 ...
@@ -303,9 +316,14 @@ def eval_input_fn(params, ctx=None):
   if params["use_synthetic_data"]:
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
-      file_pattern, params["batch_size"], params["max_length"],
-      params["max_io_parallelism"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"], num_replicas=params["num_gpus"],
-      ctx=ctx)
+      file_pattern,
+      params["batch_size"],
+      params["max_length"],
+      params["max_io_parallelism"],
+      shuffle=False,
+      repeat=1,
+      static_batch=params["static_batch"],
+      num_replicas=params["num_gpus"],
+      ctx=ctx)
 ...
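
For context on the group_by_window call reformatted above: it buckets examples by a key (here, length) and batches each bucket separately. A self-contained toy version follows; the bucket rule, window size, and data are illustrative, not the Transformer pipeline's values.

import tensorflow as tf

# Variable-length "sentences" of 1..8 tokens.
ds = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in [3, 7, 2, 8, 4, 6]),
    output_signature=tf.TensorSpec([None], tf.int32))

def example_to_bucket_id(example):
  # Bucket 0 for short examples, bucket 1 for long ones.
  return tf.cast(tf.size(example) > 4, tf.int64)

def batching_fn(bucket_id, grouped_dataset):
  del bucket_id
  return grouped_dataset.padded_batch(2, padded_shapes=[None])

batched = ds.apply(
    tf.data.experimental.group_by_window(
        key_func=example_to_bucket_id,
        reduce_func=batching_fn,
        window_size=2))

for batch in batched:
  print(batch.shape)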
official/nlp/transformer/embedding_layer.py

@@ -60,6 +60,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     Args:
       inputs: An int64 tensor with shape [batch_size, length]
       mode: string, a valid value is one of "embedding" and "linear".

     Returns:
       outputs: (1) If mode == "embedding", output embedding tensor, float32 with
         shape [batch_size, length, embedding_size]; (2) mode == "linear", output
 ...
@@ -82,7 +83,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
       mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
       embeddings *= tf.expand_dims(mask, -1)
       # Scale embedding by the sqrt of the hidden size
       embeddings *= self.hidden_size**0.5

       return embeddings
 ...
@@ -91,6 +92,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     Args:
       inputs: A float32 tensor with shape [batch_size, length, hidden_size]

     Returns:
       float32 tensor with shape [batch_size, length, vocab_size].
     """
 ...
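
A short sketch of the embedding-mode behavior shown above: look up embeddings, zero out padding positions (id 0), and scale by sqrt(hidden_size). Vocabulary and sizes are toy values.

import tensorflow as tf

vocab_size, hidden_size = 10, 4
table = tf.random.normal([vocab_size, hidden_size])
inputs = tf.constant([[5, 2, 0, 0]])           # 0 is the padding id.

embeddings = tf.gather(table, inputs)          # [batch, length, hidden_size]
mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
embeddings *= tf.expand_dims(mask, -1)         # zero out padded positions
embeddings *= hidden_size**0.5                 # scale by sqrt(hidden_size)
print(embeddings.shape)                        # (1, 4, 4)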
official/nlp/transformer/misc.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 # pylint: disable=g-bad-import-order
 from absl import flags
 import tensorflow as tf
 ...
@@ -66,28 +67,34 @@ def define_transformer_flags():
       tf_gpu_thread_mode=True,
       datasets_num_private_threads=True,
       enable_xla=True,
       fp16_implementation=True
   )

   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)

   flags.DEFINE_integer(
       name='train_steps', short_name='ts', default=300000,
       help=flags_core.help_wrap('The number of steps used to train.'))
   flags.DEFINE_integer(
       name='steps_between_evals', short_name='sbe', default=5000,
       help=flags_core.help_wrap(
           'The Number of training steps to run between evaluations. This is '
           'used if --train_steps is defined.'))
   flags.DEFINE_boolean(
       name='enable_time_history', default=True,
       help='Whether to enable TimeHistory callback.')
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
       help='Whether to enable Tensorboard callback.')
   flags.DEFINE_boolean(
       name='enable_metrics_in_training', default=False,
       help='Whether to enable metrics during training.')
   flags.DEFINE_boolean(
       name='enable_mlir_bridge',
 ...
@@ -100,7 +107,9 @@ def define_transformer_flags():
   # Add transformer-specific flags
   flags.DEFINE_enum(
       name='param_set', short_name='mp', default='big',
       enum_values=PARAMS_MAP.keys(),
       help=flags_core.help_wrap(
           'Parameter set to use when creating and training the model. The '
 ...
@@ -111,7 +120,9 @@ def define_transformer_flags():
           'complete list of parameters, please see model/model_params.py.'))
   flags.DEFINE_bool(
       name='static_batch', short_name='sb', default=False,
       help=flags_core.help_wrap(
           'Whether the batches in the dataset should have static shapes. In '
           'general, this setting should be False. Dynamic shapes allow the '
 ...
@@ -120,7 +131,9 @@ def define_transformer_flags():
           'must be static (e.g. running on TPU), this setting will be ignored '
           'and static batching will always be used.'))
   flags.DEFINE_integer(
       name='max_length', short_name='ml', default=256,
       help=flags_core.help_wrap(
           'Max sentence length for Transformer. Default is 256. Note: Usually '
           'it is more effective to use a smaller max length if static_batch is '
 ...
@@ -128,30 +141,39 @@ def define_transformer_flags():
   # Flags for training with steps (may be used for debugging)
   flags.DEFINE_integer(
       name='validation_steps', short_name='vs', default=64,
       help=flags_core.help_wrap('The number of steps used in validation.'))

   # BLEU score computation
   flags.DEFINE_string(
       name='bleu_source', short_name='bls', default=None,
       help=flags_core.help_wrap(
           'Path to source file containing text translate when calculating the '
           'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
       ))
   flags.DEFINE_string(
       name='bleu_ref', short_name='blr', default=None,
       help=flags_core.help_wrap(
           'Path to source file containing text translate when calculating the '
           'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
       ))
   flags.DEFINE_string(
       name='vocab_file', short_name='vf', default=None,
       help=flags_core.help_wrap(
           'Path to subtoken vocabulary file. If data_download.py was used to '
           'download and encode the training data, look in the data_dir to find '
           'the vocab file.'))
   flags.DEFINE_string(
       name='mode', default='train',
       help=flags_core.help_wrap('mode: train, eval, or predict'))
   flags.DEFINE_bool(
       name='use_ctl',
 ...
@@ -188,9 +210,10 @@ def define_transformer_flags():
           'Whether to do checkpointing during training. When running under '
           'benchmark harness, we will avoid checkpointing.'))

-  flags_core.set_defaults(data_dir='/tmp/translate_ende',
-                          model_dir='/tmp/transformer_model',
-                          batch_size=None)
+  flags_core.set_defaults(
+      data_dir='/tmp/translate_ende',
+      model_dir='/tmp/transformer_model',
+      batch_size=None)

   # pylint: disable=unused-variable
   @flags.multi_flags_validator(
 ...
@@ -203,11 +226,12 @@ def define_transformer_flags():
   @flags.multi_flags_validator(
       ['bleu_source', 'bleu_ref', 'vocab_file'],
       message='--vocab_file must be defined if --bleu_source and --bleu_ref '
      'are defined.')
   def _check_bleu_vocab_file(flags_dict):
     if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
       return flags_dict['vocab_file'] is not None
     return True
   # pylint: enable=unused-variable
 ...
@@ -256,5 +280,5 @@ def update_stats(history, stats, callbacks):
     if len(timestamp_log) > 1:
       stats['avg_exp_per_second'] = (
           callback.batch_size * callback.log_steps *
           (len(callback.timestamp_log) - 1) /
           (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
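
For reference, the throughput statistic updated at the end of this file is plain arithmetic over the TimeHistory callback's timestamp log. A hedged standalone version, with illustrative numbers:

def avg_examples_per_second(batch_size, log_steps, timestamps):
  """Examples/sec between the first and last logged timestamps."""
  if len(timestamps) <= 1:
    return None
  steps = log_steps * (len(timestamps) - 1)
  return batch_size * steps / (timestamps[-1] - timestamps[0])

# Example: batches of 64, a timestamp every 100 steps.
print(avg_examples_per_second(64, 100, [0.0, 10.0, 20.0]))  # 640.0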