Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
772
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
184 additions
and
1080 deletions
+184
-1080
official/nlp/modeling/ops/__init__.py
official/nlp/modeling/ops/__init__.py
+3
-1
official/nlp/modeling/ops/beam_search.py
official/nlp/modeling/ops/beam_search.py
+20
-14
official/nlp/modeling/ops/beam_search_test.py
official/nlp/modeling/ops/beam_search_test.py
+8
-5
official/nlp/modeling/ops/decoding_module.py
official/nlp/modeling/ops/decoding_module.py
+43
-18
official/nlp/modeling/ops/decoding_module_test.py
official/nlp/modeling/ops/decoding_module_test.py
+2
-1
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+47
-16
official/nlp/modeling/ops/segment_extractor.py
official/nlp/modeling/ops/segment_extractor.py
+1
-1
official/nlp/modeling/ops/segment_extractor_test.py
official/nlp/modeling/ops/segment_extractor_test.py
+1
-1
official/nlp/optimization.py
official/nlp/optimization.py
+9
-129
official/nlp/projects/__init__.py
official/nlp/projects/__init__.py
+0
-14
official/nlp/projects/bigbird/__init__.py
official/nlp/projects/bigbird/__init__.py
+0
-14
official/nlp/projects/bigbird/encoder.py
official/nlp/projects/bigbird/encoder.py
+0
-238
official/nlp/projects/bigbird/encoder_test.py
official/nlp/projects/bigbird/encoder_test.py
+0
-63
official/nlp/projects/bigbird/experiment_configs.py
official/nlp/projects/bigbird/experiment_configs.py
+0
-100
official/nlp/projects/teams/__init__.py
official/nlp/projects/teams/__init__.py
+0
-14
official/nlp/projects/teams/teams_experiments_test.py
official/nlp/projects/teams/teams_experiments_test.py
+0
-38
official/nlp/projects/triviaqa/__init__.py
official/nlp/projects/triviaqa/__init__.py
+0
-14
official/nlp/projects/triviaqa/train.py
official/nlp/projects/triviaqa/train.py
+0
-384
official/nlp/serving/__init__.py
official/nlp/serving/__init__.py
+15
-0
official/nlp/serving/export_savedmodel.py
official/nlp/serving/export_savedmodel.py
+35
-15
No files found.
Too many changes to show.
To preserve performance only
772 of 772+
files are displayed.
Plain diff
Email patch
official/nlp/modeling/ops/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,5 +14,7 @@
"""Ops package definition."""
from
official.nlp.modeling.ops.beam_search
import
sequence_beam_search
from
official.nlp.modeling.ops.beam_search
import
SequenceBeamSearch
from
official.nlp.modeling.ops.sampling_module
import
SamplingModule
from
official.nlp.modeling.ops.segment_extractor
import
get_next_sentence_labels
from
official.nlp.modeling.ops.segment_extractor
import
get_sentence_order_labels
official/nlp/modeling/ops/beam_search.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module):
max_decode_length
,
eos_id
,
padded_decode
,
dtype
=
tf
.
float32
):
dtype
=
tf
.
float32
,
decoding_name
=
None
):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are: ids ->
A tensor with shape [batch_size * beam_size, index]. index -> A
scalar. cache -> A nested dictionary of tensors [batch_size *
beam_size, ...].
The function must return a tuple of logits and the updated cache: logits
-> A tensor with shape [batch * beam_size, vocab_size]. updated cache
-> A nested dictionary with the same structure as the input cache.
symbols_to_logits_fn: A function to provide logits, which is the interface
to the Transformer model. The passed in arguments are: ids -> A tensor
with shape [batch_size * beam_size, index]. index -> A scalar. cache ->
A nested dictionary of tensors [batch_size * beam_size, ...]. The
function must return a tuple of logits and the updated cache: logits ->
A tensor with shape [batch * beam_size, vocab_size]. updated cache -> A
nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
beam_size: An integer, number of beams for beam search.
...
...
@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module):
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
"""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
vocab_size
=
vocab_size
...
...
@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module):
self
.
eos_id
=
eos_id
self
.
padded_decode
=
padded_decode
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
decoding_name
=
decoding_name
def
search
(
self
,
initial_ids
,
initial_cache
):
"""Beam search for sequences with highest scores.
...
...
@@ -204,7 +206,7 @@ class SequenceBeamSearch(tf.Module):
candidate_log_probs
=
_log_prob_from_logits
(
logits
)
# Calculate new log probabilities if each of the alive sequences were
# extended # by the
the
candidate IDs.
# extended # by the candidate IDs.
# Shape [batch_size, beam_size, vocab_size]
log_probs
=
candidate_log_probs
+
tf
.
expand_dims
(
alive_log_probs
,
axis
=
2
)
...
...
@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module):
_search_step
,
loop_vars
=
[
state
],
shape_invariants
=
[
state_shapes
],
parallel_iterations
=
1
))
parallel_iterations
=
1
,
name
=
self
.
decoding_name
))
finished_state
=
finished_state
[
0
]
return
self
.
_process_finished_state
(
finished_state
)
...
...
@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn,
max_decode_length
,
eos_id
,
padded_decode
=
False
,
dtype
=
"float32"
):
dtype
=
"float32"
,
decoding_name
=
None
):
"""Search for sequence of subtoken ids with the largest probability.
Args:
...
...
@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn,
beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
sbs
=
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
)
max_decode_length
,
eos_id
,
padded_decode
,
dtype
,
decoding_name
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
...
...
official/nlp/modeling/ops/beam_search_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
y
)
@
parameterized
.
named_parameters
([
(
'padded_decode_true'
,
True
),
(
'padded_decode_false'
,
False
),
(
'padded_decode_true_with_name'
,
True
,
'decoding'
),
(
'padded_decode_false_with_name'
,
False
,
'decoding'
),
(
'padded_decode_true_without_name'
,
True
,
None
),
(
'padded_decode_false_without_name'
,
False
,
None
),
])
def
test_sequence_beam_search
(
self
,
padded_decode
):
def
test_sequence_beam_search
(
self
,
padded_decode
,
name
):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities
=
tf
.
constant
([[[
0.2
,
0.7
,
0.1
],
[
0.5
,
0.3
,
0.2
],
[
0.1
,
0.8
,
0.1
]],
...
...
@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
max_decode_length
=
3
,
eos_id
=
9
,
padded_decode
=
padded_decode
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
,
decoding_name
=
name
)
self
.
assertAllEqual
([[[
0
,
1
,
0
,
1
],
[
0
,
1
,
1
,
2
]]],
predictions
)
...
...
official/nlp/modeling/ops/decoding_module.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -15,14 +15,14 @@
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import
abc
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
import
tensorflow
as
tf
from
tensorflow.python.framework
import
dtypes
from
official.modeling
import
tf_utils
Output
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]
Output
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
Optional
[
tf
.
Tensor
]
]
InternalState
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
,
Dict
]
InitialState
=
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]
...
...
@@ -46,6 +46,10 @@ class StateKeys:
# the previous iteration.
ALIVE_CACHE
=
"ALIVE_CACHE"
# The initial model state/cache after model processing the initial token.
# The cache will be filled if extra_cache_output is true.
INITIAL_OUTPUT_CACHE
=
"INITIAL_OUTPUT_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
...
...
@@ -108,7 +112,9 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def
__init__
(
self
,
length_normalization_fn
:
Callable
[[
int
,
tf
.
DType
],
float
],
dtype
:
tf
.
DType
=
tf
.
float32
):
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
,
extra_cache_output
:
bool
=
False
):
"""Initialize the Decoding Module.
Args:
...
...
@@ -116,31 +122,39 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
parameter. Function accepts input as length, dtype and returns float.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
extra_cache_output: If true, the first cache will be in the states.
"""
self
.
length_normalization_fn
=
length_normalization_fn
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
decoding_name
=
decoding_name
def
generate
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
])
->
Output
:
initial_cache
:
Dict
[
str
,
tf
.
Tensor
],
initial_log_probs
:
Optional
[
tf
.
Tensor
]
=
None
)
->
Output
:
"""Implements the decoding strategy (beam_search or sampling).
Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor
with shape [batch_size, 1]
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor
with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step.
initial_log_probs: Optionally initial log probs if there is a prefix
sequence we want to start to decode from.
Returns:
Tuple of tensors representing
finished_sequence: shape [batch, max_seq_length]
finished_scores: [batch]
first_cache: The cache after init token
"""
batch_size
=
(
initial_ids
.
shape
.
as_list
()[
0
]
if
self
.
padded_decode
else
tf
.
shape
(
initial_ids
)[
0
])
state
,
state_shapes
=
self
.
_create_initial_state
(
initial_ids
,
initial_cach
e
,
batch_size
)
state
,
state_shapes
=
self
.
_create_initial_state
(
initial_ids
,
initial_cache
,
batch_siz
e
,
initial_log_probs
)
def
_generate_step
(
state
):
topk_seq
,
topk_log_probs
,
topk_ids
,
new_cache
=
self
.
_grow_alive_seq
(
...
...
@@ -160,6 +174,17 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
}
new_state
.
update
(
alive_state
)
new_state
.
update
(
finished_state
)
if
self
.
extra_cache_output
:
i
=
state
[
StateKeys
.
CUR_INDEX
]
old_cache
=
state
[
StateKeys
.
INITIAL_OUTPUT_CACHE
]
def
update_with_cache
(
new_state
,
cache
):
"""Updates new_state with cache."""
new_state
.
update
({
StateKeys
.
INITIAL_OUTPUT_CACHE
:
cache
})
tf
.
cond
(
tf
.
equal
(
i
,
0
),
lambda
:
update_with_cache
(
new_state
,
new_cache
),
lambda
:
update_with_cache
(
new_state
,
old_cache
))
return
[
new_state
]
finished_state
=
tf
.
nest
.
map_structure
(
...
...
@@ -169,15 +194,18 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
_generate_step
,
loop_vars
=
[
state
],
shape_invariants
=
[
state_shapes
],
parallel_iterations
=
1
))
parallel_iterations
=
1
,
name
=
self
.
decoding_name
))
final_state
=
self
.
_process_finished_state
(
finished_state
[
0
])
return
final_state
@
abc
.
abstractmethod
def
_create_initial_state
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
],
batch_size
:
int
)
->
InitialState
:
def
_create_initial_state
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
],
batch_size
:
int
,
initial_log_probs
:
Optional
[
tf
.
Tensor
]
=
None
)
->
InitialState
:
"""Return initial state dictionary and its shape invariants."""
pass
...
...
@@ -277,6 +305,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
return
dtypes
.
float16
.
max
else
:
raise
AssertionError
(
"Invalid dtype: %s"
%
self
.
dtype
)
official/nlp/modeling/ops/decoding_module_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -29,6 +29,7 @@ class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def
__init__
(
self
,
length_normalization_fn
=
length_normalization
,
extra_cache_output
=
True
,
dtype
=
tf
.
float32
):
super
(
TestSubclass
,
self
).
__init__
(
length_normalization_fn
=
length_normalization
,
dtype
=
dtype
)
...
...
official/nlp/modeling/ops/sampling_module.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -55,6 +55,8 @@ def sample_top_k(logits, top_k):
Returns:
Logits with top_k filtering applied.
"""
top_k
=
tf
.
clip_by_value
(
top_k
,
clip_value_min
=
1
,
clip_value_max
=
tf
.
shape
(
logits
)[
-
1
])
top_k_logits
=
tf
.
math
.
top_k
(
logits
,
k
=
top_k
)
indices_to_remove
=
logits
<
tf
.
expand_dims
(
top_k_logits
[
0
][...,
-
1
],
-
1
)
top_k_logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
...
...
@@ -160,7 +162,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
top_p
=
1.0
,
sample_temperature
=
0.0
,
enable_greedy
:
bool
=
True
,
dtype
:
tf
.
DType
=
tf
.
float32
):
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
,
extra_cache_output
:
bool
=
False
):
"""Initialize sampling module."""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
length_normalization_fn
=
length_normalization_fn
...
...
@@ -174,8 +178,13 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
sample_temperature
,
dtype
=
tf
.
float32
)
self
.
enable_greedy
=
enable_greedy
self
.
decoding_name
=
decoding_name
self
.
extra_cache_output
=
extra_cache_output
super
(
SamplingModule
,
self
).
__init__
(
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
,
decoding_name
=
decoding_name
,
extra_cache_output
=
extra_cache_output
)
def
_grow_alive_seq
(
self
,
state
:
Dict
[
str
,
Any
],
...
...
@@ -241,10 +250,13 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
topk_seq
=
tf
.
concat
([
alive_seq
,
topk_ids
],
axis
=-
1
)
return
topk_seq
,
topk_log_probs
,
topk_ids
,
new_cache
def
_create_initial_state
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
],
batch_size
:
int
)
->
decoding_module
.
InitialState
:
def
_create_initial_state
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
],
batch_size
:
int
,
initial_log_probs
:
Optional
[
tf
.
Tensor
]
=
None
)
->
decoding_module
.
InitialState
:
"""Return initial state dictionary and its shape invariants."""
for
key
,
value
in
initial_cache
.
items
():
for
inner_value
in
tf
.
nest
.
flatten
(
value
):
...
...
@@ -264,8 +276,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
alive_seq
=
tf
.
tile
(
alive_seq
,
[
1
,
self
.
max_decode_length
+
1
])
# Initial log probabilities with shape [batch_size, 1].
initial_log_probs
=
tf
.
constant
([[
0.
]],
dtype
=
self
.
dtype
)
alive_log_probs
=
tf
.
tile
(
initial_log_probs
,
[
batch_size
,
1
])
if
initial_log_probs
is
None
:
initial_log_probs
=
tf
.
constant
([[
0.
]],
dtype
=
self
.
dtype
)
alive_log_probs
=
tf
.
tile
(
initial_log_probs
,
[
batch_size
,
1
])
else
:
alive_log_probs
=
initial_log_probs
alive_cache
=
initial_cache
...
...
@@ -294,16 +309,14 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module
.
StateKeys
.
CUR_INDEX
:
tf
.
TensorShape
([]),
decoding_module
.
StateKeys
.
ALIVE_SEQ
:
tf
.
TensorShape
(
[
batch_size
,
self
.
max_decode_length
+
1
]),
tf
.
TensorShape
([
batch_size
,
self
.
max_decode_length
+
1
]),
decoding_module
.
StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
batch_size
,
1
]),
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
tf
.
nest
.
map_structure
(
lambda
state
:
state
.
get_shape
(),
alive_cache
),
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
(
[
batch_size
,
self
.
max_decode_length
+
1
]),
tf
.
TensorShape
([
batch_size
,
self
.
max_decode_length
+
1
]),
decoding_module
.
StateKeys
.
FINISHED_SCORES
:
tf
.
TensorShape
([
batch_size
,
1
]),
decoding_module
.
StateKeys
.
FINISHED_FLAGS
:
...
...
@@ -318,9 +331,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module
.
StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
None
,
1
]),
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
([
None
,
None
]),
decoding_module
.
StateKeys
.
FINISHED_SCORES
:
...
...
@@ -329,6 +341,22 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
tf
.
TensorShape
([
None
,
1
])
}
if
self
.
extra_cache_output
:
state
.
update
(
{
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
alive_cache
})
if
self
.
padded_decode
:
state_shape_invariants
.
update
({
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
tf
.
nest
.
map_structure
(
lambda
state
:
state
.
get_shape
(),
alive_cache
)
})
else
:
state_shape_invariants
.
update
({
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
})
return
state
,
state_shape_invariants
def
_get_new_alive_state
(
self
,
new_seq
:
tf
.
Tensor
,
new_log_probs
:
tf
.
Tensor
,
...
...
@@ -422,6 +450,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
finished_scores
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
if
self
.
extra_cache_output
:
return
finished_seq
,
finished_scores
,
finished_state
[
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
]
return
finished_seq
,
finished_scores
def
_continue_search
(
self
,
state
)
->
tf
.
Tensor
:
...
...
official/nlp/modeling/ops/segment_extractor.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/modeling/ops/segment_extractor_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/optimization.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""
import
re
"""Legacy functions and classes related to optimization."""
from
absl
import
logging
import
gin
import
tensorflow
as
tf
import
tensorflow_addons.optimizers
as
tfa_optimizers
from
official.modeling.optimization
import
legacy_adamw
AdamWeightDecay
=
legacy_adamw
.
AdamWeightDecay
class
WarmUp
(
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
):
...
...
@@ -70,13 +71,15 @@ def create_optimizer(init_lr,
num_warmup_steps
,
end_lr
=
0.0
,
optimizer_type
=
'adamw'
,
beta_1
=
0.9
):
beta_1
=
0.9
,
poly_power
=
1.0
):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
lr_schedule
=
tf
.
keras
.
optimizers
.
schedules
.
PolynomialDecay
(
initial_learning_rate
=
init_lr
,
decay_steps
=
num_train_steps
,
end_learning_rate
=
end_lr
)
end_learning_rate
=
end_lr
,
power
=
poly_power
)
if
num_warmup_steps
:
lr_schedule
=
WarmUp
(
initial_learning_rate
=
init_lr
,
...
...
@@ -105,126 +108,3 @@ def create_optimizer(init_lr,
raise
ValueError
(
'Unsupported optimizer type: '
,
optimizer_type
)
return
optimizer
class
AdamWeightDecay
(
tf
.
keras
.
optimizers
.
Adam
):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
Instead we want to decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
"""
def
__init__
(
self
,
learning_rate
=
0.001
,
beta_1
=
0.9
,
beta_2
=
0.999
,
epsilon
=
1e-7
,
amsgrad
=
False
,
weight_decay_rate
=
0.0
,
include_in_weight_decay
=
None
,
exclude_from_weight_decay
=
None
,
gradient_clip_norm
=
1.0
,
name
=
'AdamWeightDecay'
,
**
kwargs
):
super
(
AdamWeightDecay
,
self
).
__init__
(
learning_rate
,
beta_1
,
beta_2
,
epsilon
,
amsgrad
,
name
,
**
kwargs
)
self
.
weight_decay_rate
=
weight_decay_rate
self
.
gradient_clip_norm
=
gradient_clip_norm
self
.
_include_in_weight_decay
=
include_in_weight_decay
self
.
_exclude_from_weight_decay
=
exclude_from_weight_decay
logging
.
info
(
'gradient_clip_norm=%f'
,
gradient_clip_norm
)
@
classmethod
def
from_config
(
cls
,
config
):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects
=
{
'WarmUp'
:
WarmUp
}
return
super
(
AdamWeightDecay
,
cls
).
from_config
(
config
,
custom_objects
=
custom_objects
)
def
_prepare_local
(
self
,
var_device
,
var_dtype
,
apply_state
):
super
(
AdamWeightDecay
,
self
).
_prepare_local
(
var_device
,
var_dtype
,
# pytype: disable=attribute-error # typed-keras
apply_state
)
apply_state
[(
var_device
,
var_dtype
)][
'weight_decay_rate'
]
=
tf
.
constant
(
self
.
weight_decay_rate
,
name
=
'adam_weight_decay_rate'
)
def
_decay_weights_op
(
self
,
var
,
learning_rate
,
apply_state
):
do_decay
=
self
.
_do_use_weight_decay
(
var
.
name
)
if
do_decay
:
return
var
.
assign_sub
(
learning_rate
*
var
*
apply_state
[(
var
.
device
,
var
.
dtype
.
base_dtype
)][
'weight_decay_rate'
],
use_locking
=
self
.
_use_locking
)
return
tf
.
no_op
()
def
apply_gradients
(
self
,
grads_and_vars
,
name
=
None
,
experimental_aggregate_gradients
=
True
):
grads
,
tvars
=
list
(
zip
(
*
grads_and_vars
))
if
experimental_aggregate_gradients
and
self
.
gradient_clip_norm
>
0.0
:
# when experimental_aggregate_gradients = False, apply_gradients() no
# longer implicitly allreduce gradients, users manually allreduce gradient
# and passed the allreduced grads_and_vars. For now, the
# clip_by_global_norm will be moved to before the explicit allreduce to
# keep the math the same as TF 1 and pre TF 2.2 implementation.
(
grads
,
_
)
=
tf
.
clip_by_global_norm
(
grads
,
clip_norm
=
self
.
gradient_clip_norm
)
return
super
(
AdamWeightDecay
,
self
).
apply_gradients
(
zip
(
grads
,
tvars
),
name
=
name
,
experimental_aggregate_gradients
=
experimental_aggregate_gradients
)
def
_get_lr
(
self
,
var_device
,
var_dtype
,
apply_state
):
"""Retrieves the learning rate with the given state."""
if
apply_state
is
None
:
return
self
.
_decayed_lr_t
[
var_dtype
],
{}
apply_state
=
apply_state
or
{}
coefficients
=
apply_state
.
get
((
var_device
,
var_dtype
))
if
coefficients
is
None
:
coefficients
=
self
.
_fallback_apply_state
(
var_device
,
var_dtype
)
apply_state
[(
var_device
,
var_dtype
)]
=
coefficients
return
coefficients
[
'lr_t'
],
dict
(
apply_state
=
apply_state
)
def
_resource_apply_dense
(
self
,
grad
,
var
,
apply_state
=
None
):
lr_t
,
kwargs
=
self
.
_get_lr
(
var
.
device
,
var
.
dtype
.
base_dtype
,
apply_state
)
decay
=
self
.
_decay_weights_op
(
var
,
lr_t
,
apply_state
)
with
tf
.
control_dependencies
([
decay
]):
return
super
(
AdamWeightDecay
,
self
).
_resource_apply_dense
(
grad
,
var
,
**
kwargs
)
# pytype: disable=attribute-error # typed-keras
def
_resource_apply_sparse
(
self
,
grad
,
var
,
indices
,
apply_state
=
None
):
lr_t
,
kwargs
=
self
.
_get_lr
(
var
.
device
,
var
.
dtype
.
base_dtype
,
apply_state
)
decay
=
self
.
_decay_weights_op
(
var
,
lr_t
,
apply_state
)
with
tf
.
control_dependencies
([
decay
]):
return
super
(
AdamWeightDecay
,
self
).
_resource_apply_sparse
(
grad
,
var
,
indices
,
**
kwargs
)
# pytype: disable=attribute-error # typed-keras
def
get_config
(
self
):
config
=
super
(
AdamWeightDecay
,
self
).
get_config
()
config
.
update
({
'weight_decay_rate'
:
self
.
weight_decay_rate
,
})
return
config
def
_do_use_weight_decay
(
self
,
param_name
):
"""Whether to use L2 weight decay for `param_name`."""
if
self
.
weight_decay_rate
==
0
:
return
False
if
self
.
_include_in_weight_decay
:
for
r
in
self
.
_include_in_weight_decay
:
if
re
.
search
(
r
,
param_name
)
is
not
None
:
return
True
if
self
.
_exclude_from_weight_decay
:
for
r
in
self
.
_exclude_from_weight_decay
:
if
re
.
search
(
r
,
param_name
)
is
not
None
:
return
False
return
True
official/nlp/projects/__init__.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/projects/bigbird/__init__.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/projects/bigbird/encoder.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import
tensorflow
as
tf
from
official.modeling
import
activations
from
official.nlp
import
modeling
from
official.nlp.modeling
import
layers
from
official.nlp.projects.bigbird
import
recompute_grad
from
official.nlp.projects.bigbird
import
recomputing_dropout
_MAX_SEQ_LEN
=
4096
class
RecomputeTransformerLayer
(
layers
.
TransformerScaffold
):
"""Transformer layer that recomputes the forward pass during backpropagation."""
def
call
(
self
,
inputs
,
training
=
None
):
emb
,
mask
=
inputs
def
f
(
*
args
):
# recompute_grad can only handle tensor inputs. so we enumerate the
# nested input [emb, mask] as follows:
# args[0]: emb
# args[1]: mask[0] = band_mask
# args[2]: mask[1] = encoder_from_mask
# args[3]: mask[2] = encoder_to_mask
# args[4]: mask[3] = blocked_encoder_mask
x
=
super
(
RecomputeTransformerLayer
,
self
).
call
([
args
[
0
],
[
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
]]],
training
=
training
)
return
x
f
=
recompute_grad
.
recompute_grad
(
f
)
return
f
(
emb
,
*
mask
)
@tf.keras.utils.register_keras_serializable(package='Text')
class BigBirdEncoder(tf.keras.Model):
  """Transformer-based encoder network with BigBird attentions.

  *Note* that the network is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  The model takes three int32 inputs named `input_word_ids`, `input_mask` and
  `input_type_ids` (each declared with shape `(None,)` per example) and
  returns a dict with `sequence_output` (the output of the last transformer
  layer) and `encoder_outputs` (the list of outputs of every layer).

  Args:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer. The
      hidden size must be divisible by the number of attention heads.
    max_position_embeddings: The maximum length of position embeddings that this
      encoder can consume. If None, max_position_embeddings uses the value from
      sequence length. This determines the variable shape for positional
      embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    block_size: int. A BigBird Attention parameter: size of block in from/to
      sequences.
    num_rand_blocks: int. A BigBird Attention parameter: number of random chunks
      per row.
    activation: The activation to use for the transformer layers.
    dropout_rate: The dropout rate to use for the transformer layers.
    attention_dropout_rate: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
    embedding_width: The width of the word embeddings. If the embedding width is
      not equal to hidden size, embedding parameters will be factorized into two
      matrices in the shape of ['vocab_size', 'embedding_width'] and
      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
      smaller than 'hidden_size').
    use_gradient_checkpointing: Use gradient checkpointing to trade-off compute
      for memory. When True, `RecomputeTransformerLayer` is used for every
      layer and `tf.keras.layers.Dropout` is rebound, for the whole process,
      to `recomputing_dropout.RecomputingDropout`.
    **kwargs: Forwarded to `tf.keras.Model.__init__`.
  """

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_position_embeddings=_MAX_SEQ_LEN,
               type_vocab_size=16,
               intermediate_size=3072,
               block_size=64,
               num_rand_blocks=3,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               embedding_width=None,
               use_gradient_checkpointing=False,
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    if use_gradient_checkpointing:
      # NOTE(review): this rebinds the attribute on the shared
      # `tf.keras.layers` module, so every Dropout created afterwards in this
      # process (not only by this encoder) becomes a RecomputingDropout.
      tf.keras.layers.Dropout = recomputing_dropout.RecomputingDropout
      layer_cls = RecomputeTransformerLayer
    else:
      layer_cls = layers.TransformerScaffold

    # Attributes are assigned before `super().__init__`, so attribute
    # tracking is switched off first.
    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_position_embeddings': max_position_embeddings,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'block_size': block_size,
        'num_rand_blocks': num_rand_blocks,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        # Recorded as given (None stays None) so from_config reproduces the
        # same constructor call.
        'embedding_width': embedding_width,
        # Previously omitted, which made a get_config/from_config round trip
        # silently fall back to the non-checkpointing layer class.
        'use_gradient_checkpointing': use_gradient_checkpointing,
    }

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    if embedding_width is None:
      embedding_width = hidden_size
    self._embedding_layer = modeling.layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = modeling.layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_position_embeddings,
        name='position_embedding')
    position_embeddings = self._position_embedding_layer(word_embeddings)
    self._type_embedding_layer = modeling.layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)

    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = self._embedding_projection(embeddings)

    self._transformer_layers = []
    data = embeddings
    masks = layers.BigBirdMasks(block_size=block_size)(data, mask)
    encoder_outputs = []
    attn_head_dim = hidden_size // num_attention_heads
    for i in range(num_layers):
      layer = layer_cls(
          num_attention_heads,
          intermediate_size,
          activation,
          attention_cls=layers.BigBirdAttention,
          attention_cfg=dict(
              num_heads=num_attention_heads,
              key_dim=attn_head_dim,
              kernel_initializer=initializer,
              from_block_size=block_size,
              to_block_size=block_size,
              num_rand_blocks=num_rand_blocks,
              max_rand_mask_length=max_position_embeddings,
              # The layer index is used as the attention seed, so each layer
              # gets a different one.
              seed=i),
          dropout_rate=dropout_rate,
          # Was `dropout_rate`, which left the documented
          # `attention_dropout_rate` argument without any effect on the
          # layers. Identical when both rates are equal (the default).
          attention_dropout_rate=attention_dropout_rate,
          kernel_initializer=initializer)
      self._transformer_layers.append(layer)
      data = layer([data, masks])
      encoder_outputs.append(data)

    outputs = dict(
        sequence_output=encoder_outputs[-1], encoder_outputs=encoder_outputs)
    super().__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

  def get_embedding_table(self):
    """Returns the `embeddings` attribute of the word embedding layer."""
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    """Returns the word embedding layer."""
    return self._embedding_layer

  def get_config(self):
    """Returns a copy of the constructor arguments recorded at build time."""
    # A shallow copy keeps callers from mutating the encoder's own record.
    return dict(self._config_dict)

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    # NOTE(review): `_pooler_layer` is never assigned in this class as shown,
    # so this access raises AttributeError unless something else sets it.
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds an encoder from `get_config` output.

    `custom_objects` is accepted but not used.
    """
    return cls(**config)
official/nlp/projects/bigbird/encoder_test.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.bigbird.encoder."""
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.projects.bigbird
import
encoder
class BigBirdEncoderTest(tf.test.TestCase):
  """Smoke tests for `encoder.BigBirdEncoder`."""

  def test_encoder(self):
    """A forward pass yields a (batch, sequence, 768) `sequence_output`."""
    sequence_length = 1024
    batch_size = 2
    vocab_size = 1024
    id_shape = (batch_size, sequence_length)
    network = encoder.BigBirdEncoder(
        num_layers=1, vocab_size=vocab_size, max_position_embeddings=4096)
    word_id_data = np.random.randint(vocab_size, size=id_shape)
    mask_data = np.random.randint(2, size=id_shape)
    type_id_data = np.random.randint(2, size=id_shape)
    outputs = network([word_id_data, mask_data, type_id_data])
    expected_shape = (batch_size, sequence_length, 768)
    self.assertEqual(outputs["sequence_output"].shape, expected_shape)

  def test_save_restore(self):
    """A saved and reloaded model reproduces the original outputs."""
    sequence_length = 1024
    batch_size = 2
    vocab_size = 1024
    id_shape = (batch_size, sequence_length)
    network = encoder.BigBirdEncoder(
        num_layers=1, vocab_size=vocab_size, max_position_embeddings=4096)
    word_id_data = np.random.randint(vocab_size, size=id_shape)
    mask_data = np.random.randint(2, size=id_shape)
    type_id_data = np.random.randint(2, size=id_shape)
    inputs = {
        "input_word_ids": word_id_data,
        "input_mask": mask_data,
        "input_type_ids": type_id_data,
    }
    ref_outputs = network(inputs)

    model_path = f"{self.get_temp_dir()}/model"
    network.save(model_path)
    loaded = tf.keras.models.load_model(model_path)

    restored_outputs = loaded(inputs)
    self.assertAllClose(restored_outputs["sequence_output"],
                        ref_outputs["sequence_output"])
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
official/nlp/projects/bigbird/experiment_configs.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bigbird experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.nlp.data
import
question_answering_dataloader
from
official.nlp.data
import
sentence_prediction_dataloader
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
sentence_prediction
@exp_factory.register_config_factory('bigbird/glue')
def bigbird_glue() -> cfg.ExperimentConfig:
  """Builds the experiment config registered as 'bigbird/glue'.

  Returns:
    An `ExperimentConfig` with a sentence-prediction task, an AdamW optimizer
    with polynomial learning-rate decay and polynomial warmup, and the task's
    encoder type set to 'bigbird'.
  """
  task_config = sentence_prediction.SentencePredictionConfig(
      train_data=sentence_prediction_dataloader.SentencePredictionDataConfig(),
      validation_data=sentence_prediction_dataloader
      .SentencePredictionDataConfig(is_training=False, drop_remainder=False))

  adamw_settings = {
      'weight_decay_rate': 0.01,
      'exclude_from_weight_decay': ['LayerNorm', 'layer_norm', 'bias'],
  }
  polynomial_settings = {
      'initial_learning_rate': 3e-5,
      'end_learning_rate': 0.0,
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'adamw',
          'adamw': adamw_settings,
      },
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': polynomial_settings,
      },
      'warmup': {
          'type': 'polynomial'
      },
  })
  trainer_config = cfg.TrainerConfig(optimizer_config=optimizer_config)

  config = cfg.ExperimentConfig(
      task=task_config,
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  config.task.model.encoder.type = 'bigbird'
  return config
@exp_factory.register_config_factory('bigbird/squad')
def bigbird_squad() -> cfg.ExperimentConfig:
  """Builds the experiment config registered as 'bigbird/squad'.

  Returns:
    An `ExperimentConfig` with a question-answering task, an AdamW optimizer
    with polynomial learning-rate decay and polynomial warmup, and the task's
    encoder type set to 'bigbird'.
  """
  task_config = question_answering.QuestionAnsweringConfig(
      train_data=question_answering_dataloader.QADataConfig(),
      validation_data=question_answering_dataloader.QADataConfig())

  adamw_settings = {
      'weight_decay_rate': 0.01,
      'exclude_from_weight_decay': ['LayerNorm', 'layer_norm', 'bias'],
  }
  polynomial_settings = {
      'initial_learning_rate': 8e-5,
      'end_learning_rate': 0.0,
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'adamw',
          'adamw': adamw_settings,
      },
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': polynomial_settings,
      },
      'warmup': {
          'type': 'polynomial'
      },
  })
  trainer_config = cfg.TrainerConfig(optimizer_config=optimizer_config)

  config = cfg.ExperimentConfig(
      task=task_config,
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  config.task.model.encoder.type = 'bigbird'
  return config
official/nlp/projects/teams/__init__.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/projects/teams/teams_experiments_test.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for teams_experiments."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: enable=unused-import
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
class TeamsExperimentsTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that the registered TEAMS experiment configs can be built."""

  @parameterized.parameters(('teams/pretraining',))
  def test_teams_experiments(self, config_name):
    """The named experiment resolves to well-typed config objects."""
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    train_data = experiment.task.train_data
    self.assertIsInstance(train_data, cfg.DataConfig)
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
official/nlp/projects/triviaqa/__init__.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/projects/triviaqa/train.py
deleted
100644 → 0
View file @
9485aa1d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA training script."""
import
collections
import
contextlib
import
functools
import
json
import
operator
import
os
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
gin
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
import
sentencepiece
as
spm
from
official.nlp
import
optimization
as
nlp_optimization
from
official.nlp.configs
import
encoders
from
official.nlp.projects.triviaqa
import
evaluation
from
official.nlp.projects.triviaqa
import
inputs
from
official.nlp.projects.triviaqa
import
modeling
from
official.nlp.projects.triviaqa
import
prediction
# Command-line flags for the TriviaQA training script.

# Input / output locations.
flags.DEFINE_string('data_dir', None, 'Data directory for TensorFlow Datasets.')
flags.DEFINE_string(
    'validation_gold_path', None,
    'Path to golden validation. Usually, the wikipedia-dev.json file.')
flags.DEFINE_string('model_dir', None,
                    'Directory for checkpoints and summaries.')
# Help text previously read "coniguration".
flags.DEFINE_string('model_config_path', None,
                    'JSON file containing model configuration.')
flags.DEFINE_string('sentencepiece_model_path', None,
                    'Path to sentence piece model.')

# Encoder selection and initialization.
flags.DEFINE_enum('encoder', 'bigbird',
                  ['bert', 'bigbird', 'albert', 'mobilebert'],
                  'Which transformer encoder model to use.')
flags.DEFINE_integer('bigbird_block_size', 64,
                     'Size of blocks for sparse block attention.')
flags.DEFINE_string('init_checkpoint_path', None,
                    'Path from which to initialize weights.')

# Sequence lengths and batching.
flags.DEFINE_integer('train_sequence_length', 4096,
                     'Maximum number of tokens for training.')
flags.DEFINE_integer('train_global_sequence_length', 320,
                     'Maximum number of global tokens for training.')
flags.DEFINE_integer('validation_sequence_length', 4096,
                     'Maximum number of tokens for validation.')
flags.DEFINE_integer('validation_global_sequence_length', 320,
                     'Maximum number of global tokens for validation.')
flags.DEFINE_integer('batch_size', 32, 'Size of batch.')
flags.DEFINE_string('master', '', 'Address of the TPU master.')

# Answer decoding.
flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')
flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')

# Regularization.
flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate for hidden layers.')
flags.DEFINE_float('attention_dropout_rate', 0.3,
                   'Dropout rate for attention layers.')
flags.DEFINE_float('label_smoothing', 1e-1, 'Degree of label smoothing.')

flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files')

FLAGS = flags.FLAGS
@contextlib.contextmanager
def worker_context():
  """Context manager that enters the '/job:worker' device when --master is set.

  Yields:
    The value produced by `tf.device('/job:worker')` when `FLAGS.master` is
    truthy; otherwise `None`, with no device scope entered.
  """
  if not FLAGS.master:
    yield
    return
  with tf.device('/job:worker') as d:
    yield d
def read_sentencepiece_model(path):
  """Loads a SentencePiece processor from a serialized model file.

  Args:
    path: Location of the serialized model, opened through
      `tf.io.gfile.GFile` in binary mode.

  Returns:
    A `spm.SentencePieceProcessor` loaded from the file's contents.
  """
  with tf.io.gfile.GFile(path, 'rb') as model_file:
    sp_processor = spm.SentencePieceProcessor()
    serialized_proto = model_file.read()
    sp_processor.LoadFromSerializedProto(serialized_proto)
  return sp_processor
# Rename old BERT v1 configuration parameters.
_MODEL_CONFIG_REPLACEMENTS
=
{
'num_hidden_layers'
:
'num_layers'
,
'attention_probs_dropout_prob'
:
'attention_dropout_rate'
,
'hidden_dropout_prob'
:
'dropout_rate'
,
'hidden_act'
:
'hidden_activation'
,
'window_size'
:
'block_size'
,
}
def read_model_config(encoder,
                      path,
                      bigbird_block_size=None) -> encoders.EncoderConfig:
  """Merges the JSON configuration into the encoder configuration.

  Legacy keys listed in `_MODEL_CONFIG_REPLACEMENTS` are renamed first. The
  dropout rates always come from the `--dropout_rate` and
  `--attention_dropout_rate` flags, replacing any value in the file.

  Args:
    encoder: Encoder type, passed as `type` to `encoders.EncoderConfig`.
    path: Location of the JSON model configuration, read through
      `tf.io.gfile.GFile`.
    bigbird_block_size: Optional value for `block_size`. When None, the value
      from the JSON file (if any) is left untouched.

  Returns:
    An `encoders.EncoderConfig` of the requested type, overridden with every
    JSON key that the selected encoder's config knows; unknown keys are
    logged and skipped.
  """
  with tf.io.gfile.GFile(path) as f:
    model_config = json.load(f)

  for old_key, new_key in _MODEL_CONFIG_REPLACEMENTS.items():
    if old_key in model_config:
      model_config[new_key] = model_config.pop(old_key)

  model_config['attention_dropout_rate'] = FLAGS.attention_dropout_rate
  model_config['dropout_rate'] = FLAGS.dropout_rate
  # Previously assigned unconditionally, so the default (None) replaced a
  # block size coming from the JSON file with None.
  if bigbird_block_size is not None:
    model_config['block_size'] = bigbird_block_size

  encoder_config = encoders.EncoderConfig(type=encoder)
  # Override the default config with those loaded from the JSON file.
  encoder_config_keys = encoder_config.get().as_dict().keys()
  overrides = {}
  for key, value in model_config.items():
    if key in encoder_config_keys:
      overrides[key] = value
    else:
      logging.warning('Ignoring config parameter %s=%s', key, value)
  encoder_config.get().override(overrides)
  return encoder_config
@gin.configurable(denylist=[
    'model',
    'strategy',
    'train_dataset',
    'model_dir',
    'init_checkpoint_path',
    'evaluate_fn',
])
def fit(model,
        strategy,
        train_dataset,
        model_dir,
        init_checkpoint_path=None,
        evaluate_fn=None,
        learning_rate=1e-5,
        learning_rate_polynomial_decay_rate=1.,
        weight_decay_rate=1e-1,
        num_warmup_steps=5000,
        num_decay_steps=51000,
        num_epochs=6):
  """Train and evaluate.

  Compiles `model` with an `AdamWeightDecay` optimizer driven by a warmup +
  polynomial-decay schedule, restores or initializes it through a
  `CheckpointManager` rooted at `model_dir`, then runs one `model.fit` call
  per remaining epoch, saving a checkpoint after each.

  NOTE(review): block nesting below was reconstructed from a listing without
  indentation; confirm the extent of the `with` blocks against the real file.

  Args:
    model: Model to train; must expose `compile`, `fit`, `save` and an
      `encoder` attribute (used when restoring `init_checkpoint_path`).
    strategy: Distribution strategy whose `scope()` wraps optimizer creation,
      compilation and checkpoint restoration.
    train_dataset: Passed as-is to `model.fit` each epoch.
    model_dir: Directory for checkpoints, TensorBoard logs, the 'val' summary
      writer and the 'export' SavedModel.
    init_checkpoint_path: Optional checkpoint used to initialize the encoder
      through the manager's `init_fn`.
    evaluate_fn: Optional zero-argument callable returning a metrics dict that
      contains at least 'exact_match'.
    learning_rate: Initial learning rate for both warmup and decay.
    learning_rate_polynomial_decay_rate: `power` of the polynomial decay.
    weight_decay_rate: Weight decay rate for the optimizer.
    num_warmup_steps: Number of warmup steps.
    num_decay_steps: Number of decay steps.
    num_epochs: Total number of epochs, counting those already checkpointed.
  """
  hparams = dict(
      learning_rate=learning_rate,
      num_decay_steps=num_decay_steps,
      num_warmup_steps=num_warmup_steps,
      num_epochs=num_epochs,
      weight_decay_rate=weight_decay_rate,
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      label_smoothing=FLAGS.label_smoothing)
  logging.info(hparams)
  learning_rate_schedule = nlp_optimization.WarmUp(
      learning_rate,
      tf.keras.optimizers.schedules.PolynomialDecay(
          learning_rate,
          num_decay_steps,
          end_learning_rate=0.,
          power=learning_rate_polynomial_decay_rate), num_warmup_steps)
  with strategy.scope():
    optimizer = nlp_optimization.AdamWeightDecay(
        learning_rate_schedule,
        weight_decay_rate=weight_decay_rate,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
    model.compile(optimizer, loss=modeling.SpanOrCrossEntropyLoss())

  def init_fn(init_checkpoint_path):
    # Restores only the encoder from the initial checkpoint.
    ckpt = tf.train.Checkpoint(encoder=model.encoder)
    ckpt.restore(init_checkpoint_path).assert_existing_objects_matched()

  with worker_context():
    ckpt_manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=model, optimizer=optimizer),
        model_dir,
        max_to_keep=None,
        init_fn=(functools.partial(init_fn, init_checkpoint_path)
                 if init_checkpoint_path else None))
    with strategy.scope():
      ckpt_manager.restore_or_initialize()
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'val'))
    # Reset on every call, so a resumed run does not remember the best score
    # reached before the restart.
    best_exact_match = 0.
    # One checkpoint is saved per epoch, so the number of existing
    # checkpoints is used as the index of the first epoch still to run.
    for epoch in range(len(ckpt_manager.checkpoints), num_epochs):
      model.fit(
          train_dataset,
          callbacks=[
              tf.keras.callbacks.TensorBoard(model_dir, write_graph=False),
          ])
      ckpt_path = ckpt_manager.save()
      if evaluate_fn is None:
        continue
      metrics = evaluate_fn()
      logging.info('Epoch %d: %s', epoch + 1, metrics)
      # Export only when exact match strictly improves.
      if best_exact_match < metrics['exact_match']:
        best_exact_match = metrics['exact_match']
        model.save(
            os.path.join(model_dir, 'export'), include_optimizer=False)
        logging.info('Exporting %s as SavedModel.', ckpt_path)
      with val_summary_writer.as_default():
        for name, data in metrics.items():
          tf.summary.scalar(name, data, epoch + 1)
def evaluate(sp_processor, features_map_fn, labels_map_fn, logits_fn,
             decode_logits_fn, split_and_pad_fn, distribute_strategy,
             validation_dataset, ground_truth):
  """Scores the model on the validation set.

  For every batch the logits are computed, the loss metric is updated and
  candidate answers are decoded. The highest-scoring candidate per question
  id is normalized and handed to `evaluation.evaluate_triviaqa`.

  Args:
    sp_processor: Processor providing `IdToPiece`.
    features_map_fn: Maps raw features to model inputs.
    labels_map_fn: Called with `(token_ids, labels)` to produce loss targets.
    logits_fn: Computes per-replica logits from split and padded inputs.
    decode_logits_fn: Called with `(logits, end_limit)`; returns
      `(begin, end, scores)`.
    split_and_pad_fn: Prepares mapped features for `logits_fn`.
    distribute_strategy: Strategy used to gather local results.
    validation_dataset: Dataset of `(features, labels)` batches.
    ground_truth: Reference answers passed to `evaluate_triviaqa`.

  Returns:
    The metrics dict from `evaluation.evaluate_triviaqa`, plus a 'loss' entry.
  """
  loss_metric = tf.keras.metrics.Mean()

  @tf.function
  def update_loss(y, logits):
    unreduced_loss = modeling.SpanOrCrossEntropyLoss(
        reduction=tf.keras.losses.Reduction.NONE)
    return loss_metric(unreduced_loss(y, logits))

  candidates = collections.defaultdict(list)
  for _, (features, labels) in validation_dataset.enumerate():
    token_ids = features['token_ids']
    y = labels_map_fn(token_ids, labels)
    x = split_and_pad_fn(features_map_fn(features))
    local_logits = distribute_strategy.experimental_local_results(logits_fn(x))
    logits = tf.concat(local_logits, 0)
    # Keep only as many rows as the batch actually has.
    logits = logits[:features['token_ids'].shape[0]]
    update_loss(y, logits)
    end_limit = token_ids.row_lengths() - 1  # inclusive
    begin, end, scores = decode_logits_fn(logits, end_limit)
    answers = prediction.decode_answer(features['context'], begin, end,
                                       features['token_offsets'],
                                       end_limit).numpy()
    qids = features['qid'].numpy()
    begin_token_ids = tf.gather(
        features['token_ids'], begin, batch_dims=1).numpy()
    begin_offsets = tf.gather(
        features['token_offsets'], begin, batch_dims=1).numpy()
    for qid, token_id, offset, score, answer in zip(
        qids, begin_token_ids, begin_offsets, scores, answers):
      if answer:
        # Drop the first byte when the first piece carries the '▁' marker
        # and does not start at offset 0.
        starts_word = sp_processor.IdToPiece(int(token_id)).startswith('▁')
        if starts_word and offset > 0:
          answer = answer[1:]
        candidates[qid.decode('utf-8')].append((score, answer.decode('utf-8')))

  predictions = {}
  for qid, scored_answers in candidates.items():
    ranked = sorted(scored_answers, key=operator.itemgetter(0), reverse=True)
    predictions[qid] = evaluation.normalize_answer(ranked[0][1])

  metrics = evaluation.evaluate_triviaqa(ground_truth, predictions, mute=True)
  metrics['loss'] = loss_metric.result().numpy()
  return metrics
def main(argv):
  """Entry point: builds inputs, model and evaluation, then calls `fit`.

  NOTE(review): block nesting below was reconstructed from a listing without
  indentation; confirm the extent of the `with` blocks against the real file.

  Args:
    argv: Command-line arguments left after flag parsing; anything beyond the
      program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config(FLAGS.gin_bindings)
  model_config = read_model_config(
      FLAGS.encoder,
      FLAGS.model_config_path,
      bigbird_block_size=FLAGS.bigbird_block_size)
  logging.info(model_config.get().as_dict())
  # Configure input processing.
  sp_processor = read_sentencepiece_model(FLAGS.sentencepiece_model_path)
  # Special-token ids are looked up once from the SentencePiece vocabulary.
  features_map_fn = functools.partial(
      inputs.features_map_fn,
      local_radius=FLAGS.bigbird_block_size,
      relative_pos_max_distance=24,
      use_hard_g2l_mask=True,
      padding_id=sp_processor.PieceToId('<pad>'),
      eos_id=sp_processor.PieceToId('</s>'),
      null_id=sp_processor.PieceToId('<empty>'),
      cls_id=sp_processor.PieceToId('<ans>'),
      sep_id=sp_processor.PieceToId('<sep_0>'))
  train_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.train_sequence_length,
          global_sequence_length=FLAGS.train_global_sequence_length),
      autograph=False)
  train_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn, sequence_length=FLAGS.train_sequence_length))
  # Connect to TPU cluster.
  if FLAGS.master:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    # Without --master, fall back to MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
  # Initialize datasets.
  with worker_context():
    # The returned generator is discarded; the call is made only for its
    # effect.
    _ = tf.random.get_global_generator()
    train_dataset = inputs.read_batches(
        FLAGS.data_dir,
        tfds.Split.TRAIN,
        FLAGS.batch_size,
        shuffle=True,
        drop_final_batch=True)
    validation_dataset = inputs.read_batches(FLAGS.data_dir,
                                             tfds.Split.VALIDATION,
                                             FLAGS.batch_size)

  def train_map_fn(x, y):
    # Maps one raw batch to (features, smoothed labels).
    features = train_features_map_fn(x)
    labels = modeling.smooth_labels(FLAGS.label_smoothing,
                                    train_labels_map_fn(x['token_ids'], y),
                                    features['question_lengths'],
                                    features['token_ids'])
    return features, labels

  train_dataset = train_dataset.map(train_map_fn, 16).prefetch(16)
  # Initialize model and compile.
  with strategy.scope():
    model = modeling.TriviaQaModel(model_config, FLAGS.train_sequence_length)
  logits_fn = tf.function(
      functools.partial(prediction.distributed_logits_fn, model))
  decode_logits_fn = tf.function(
      functools.partial(prediction.decode_logits, FLAGS.decode_top_k,
                        FLAGS.decode_max_size))
  split_and_pad_fn = tf.function(
      functools.partial(prediction.split_and_pad, strategy, FLAGS.batch_size))
  # Evaluation strategy.
  # Gold answers keyed by question id, read from the 'Data' list of the file.
  with tf.io.gfile.GFile(FLAGS.validation_gold_path) as f:
    ground_truth = {
        datum['QuestionId']: datum['Answer'] for datum in json.load(f)['Data']
    }
  validation_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.validation_sequence_length,
          global_sequence_length=FLAGS.validation_global_sequence_length),
      autograph=False)
  validation_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn,
          sequence_length=FLAGS.validation_sequence_length))
  # Bind everything `evaluate` needs so `fit` can call it with no arguments.
  evaluate_fn = functools.partial(
      evaluate,
      sp_processor=sp_processor,
      features_map_fn=validation_features_map_fn,
      labels_map_fn=validation_labels_map_fn,
      logits_fn=logits_fn,
      decode_logits_fn=decode_logits_fn,
      split_and_pad_fn=split_and_pad_fn,
      distribute_strategy=strategy,
      validation_dataset=validation_dataset,
      ground_truth=ground_truth)
  logging.info('Model initialized. Beginning training fit loop.')
  fit(model, strategy, train_dataset, FLAGS.model_dir,
      FLAGS.init_checkpoint_path, evaluate_fn)
if __name__ == '__main__':
  # These flags default to None, so they are marked required before
  # `app.run` invokes `main`.
  flags.mark_flags_as_required([
      'model_config_path', 'model_dir', 'sentencepiece_model_path',
      'validation_gold_path'
  ])
  app.run(main)
official/nlp/serving/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/serving/export_savedmodel.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,12 +13,14 @@
# limitations under the License.
"""A binary/library to export TF-NLP serving `SavedModel`."""
import
dataclasses
import
os
from
typing
import
Any
,
Dict
,
Text
from
absl
import
app
from
absl
import
flags
import
dataclasses
import
yaml
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling
import
hyperparams
...
...
@@ -29,6 +31,7 @@ from official.nlp.tasks import masked_lm
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
tagging
from
official.nlp.tasks
import
translation
FLAGS
=
flags
.
FLAGS
...
...
@@ -40,7 +43,9 @@ SERVING_MODULES = {
question_answering
.
QuestionAnsweringTask
:
serving_modules
.
QuestionAnswering
,
tagging
.
TaggingTask
:
serving_modules
.
Tagging
serving_modules
.
Tagging
,
translation
.
TranslationTask
:
serving_modules
.
Translation
}
...
...
@@ -67,6 +72,12 @@ def define_flags():
flags
.
DEFINE_bool
(
"convert_tpu"
,
False
,
""
)
flags
.
DEFINE_multi_integer
(
"allowed_batch_size"
,
None
,
"Allowed batch sizes for batching ops."
)
flags
.
DEFINE_integer
(
"num_batch_threads"
,
4
,
"Number of threads to do TPU batching."
)
flags
.
DEFINE_integer
(
"batch_timeout_micros"
,
100000
,
"TPU batch function timeout in microseconds."
)
flags
.
DEFINE_integer
(
"max_enqueued_batches"
,
1000
,
"Max number of batches in queue for TPU batching."
)
def
lookup_export_module
(
task
:
base_task
.
Task
):
...
...
@@ -125,21 +136,30 @@ def main(_):
if
FLAGS
.
convert_tpu
:
# pylint: disable=g-import-not-at-top
from
cloud_tpu.inference_converter
import
converter_cli
from
cloud_tpu.inference_converter
import
converter_options_pb2
from
cloud_tpu.inference_converter_v2
import
converter_options_v2_pb2
from
cloud_tpu.inference_converter_v2.python
import
converter
tpu_dir
=
os
.
path
.
join
(
export_dir
,
"tpu"
)
options
=
converter_options_pb2
.
ConverterO
ptions
()
batch_o
ptions
=
[]
if
FLAGS
.
allowed_batch_size
is
not
None
:
allowed_batch_sizes
=
sorted
(
FLAGS
.
allowed_batch_size
)
options
.
batch_options
.
num_batch_threads
=
4
options
.
batch_options
.
max_batch_size
=
allowed_batch_sizes
[
-
1
]
options
.
batch_options
.
batch_timeout_micros
=
100000
options
.
batch_options
.
allowed_batch_sizes
[:]
=
allowed_batch_sizes
options
.
batch_options
.
max_enqueued_batches
=
1000
converter_cli
.
ConvertSavedModel
(
export_dir
,
tpu_dir
,
function_alias
=
"tpu_candidate"
,
options
=
options
,
graph_rewrite_only
=
True
)
batch_option
=
converter_options_v2_pb2
.
BatchOptionsV2
(
num_batch_threads
=
FLAGS
.
num_batch_threads
,
max_batch_size
=
allowed_batch_sizes
[
-
1
],
batch_timeout_micros
=
FLAGS
.
batch_timeout_micros
,
allowed_batch_sizes
=
allowed_batch_sizes
,
max_enqueued_batches
=
FLAGS
.
max_enqueued_batches
)
batch_options
.
append
(
batch_option
)
converter_options
=
converter_options_v2_pb2
.
ConverterOptionsV2
(
tpu_functions
=
[
converter_options_v2_pb2
.
TpuFunction
(
function_alias
=
"tpu_candidate"
)
],
batch_options
=
batch_options
,
)
converter
.
ConvertSavedModel
(
export_dir
,
tpu_dir
,
converter_options
)
if
__name__
==
"__main__"
:
define_flags
()
...
...
Prev
1
…
21
22
23
24
25
26
27
28
29
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment