Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b72f4975
Commit
b72f4975
authored
Feb 23, 2022
by
Frederick Liu
Committed by
A. Unique TensorFlower
Feb 23, 2022
Browse files
[seq2seq] hardcode batch size
PiperOrigin-RevId: 430563213
parent
0bcb7aa0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
4 deletions
+12
-4
official/nlp/serving/serving_modules.py
official/nlp/serving/serving_modules.py
+5
-1
official/nlp/serving/serving_modules_test.py
official/nlp/serving/serving_modules_test.py
+7
-3
No files found.
official/nlp/serving/serving_modules.py
View file @
b72f4975
...
@@ -417,6 +417,8 @@ class Translation(export_base.ExportModule):
...
@@ -417,6 +417,8 @@ class Translation(export_base.ExportModule):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
Params
(
base_config
.
Config
):
class
Params
(
base_config
.
Config
):
sentencepiece_model_path
:
str
=
""
sentencepiece_model_path
:
str
=
""
# Needs to be specified if padded_decode is True/on TPUs.
batch_size
:
Optional
[
int
]
=
None
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
def
__init__
(
self
,
params
,
model
:
tf
.
keras
.
Model
,
inference_step
=
None
):
super
().
__init__
(
params
,
model
,
inference_step
)
super
().
__init__
(
params
,
model
,
inference_step
)
...
@@ -431,6 +433,7 @@ class Translation(export_base.ExportModule):
...
@@ -431,6 +433,7 @@ class Translation(export_base.ExportModule):
"Please make sure the tokenizer generates a single token for an "
"Please make sure the tokenizer generates a single token for an "
"empty string."
)
"empty string."
)
self
.
_eos_id
=
empty_str_tokenized
.
item
()
self
.
_eos_id
=
empty_str_tokenized
.
item
()
self
.
_batch_size
=
params
.
batch_size
@
tf
.
function
@
tf
.
function
def
serve
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
def
serve
(
self
,
inputs
)
->
Dict
[
str
,
tf
.
Tensor
]:
...
@@ -452,5 +455,6 @@ class Translation(export_base.ExportModule):
...
@@ -452,5 +455,6 @@ class Translation(export_base.ExportModule):
(
self
.
__class__
,
func_key
,
valid_keys
))
(
self
.
__class__
,
func_key
,
valid_keys
))
if
func_key
==
"serve_text"
:
if
func_key
==
"serve_text"
:
signatures
[
signature_key
]
=
self
.
serve_text
.
get_concrete_function
(
signatures
[
signature_key
]
=
self
.
serve_text
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"text"
))
tf
.
TensorSpec
(
shape
=
[
self
.
_batch_size
],
dtype
=
tf
.
string
,
name
=
"text"
))
return
signatures
return
signatures
official/nlp/serving/serving_modules_test.py
View file @
b72f4975
...
@@ -344,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -344,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
_
=
export_module
.
get_inference_signatures
({
"foo"
:
None
})
_
=
export_module
.
get_inference_signatures
({
"foo"
:
None
})
def
test_translation
(
self
):
@
parameterized
.
parameters
(
(
False
,
None
),
(
True
,
2
))
def
test_translation
(
self
,
padded_decode
,
batch_size
):
sp_path
=
_make_sentencepeice
(
self
.
get_temp_dir
())
sp_path
=
_make_sentencepeice
(
self
.
get_temp_dir
())
encdecoder
=
translation
.
EncDecoder
(
encdecoder
=
translation
.
EncDecoder
(
num_attention_heads
=
4
,
intermediate_size
=
256
)
num_attention_heads
=
4
,
intermediate_size
=
256
)
...
@@ -353,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -353,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
encoder
=
encdecoder
,
encoder
=
encdecoder
,
decoder
=
encdecoder
,
decoder
=
encdecoder
,
embedding_width
=
256
,
embedding_width
=
256
,
padded_decode
=
Fals
e
,
padded_decode
=
padded_decod
e
,
decode_max_length
=
100
),
decode_max_length
=
100
),
sentencepiece_model_path
=
sp_path
,
sentencepiece_model_path
=
sp_path
,
)
)
...
@@ -361,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -361,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
model
=
task
.
build_model
()
model
=
task
.
build_model
()
params
=
serving_modules
.
Translation
.
Params
(
params
=
serving_modules
.
Translation
.
Params
(
sentencepiece_model_path
=
sp_path
)
sentencepiece_model_path
=
sp_path
,
batch_size
=
batch_size
)
export_module
=
serving_modules
.
Translation
(
params
=
params
,
model
=
model
)
export_module
=
serving_modules
.
Translation
(
params
=
params
,
model
=
model
)
functions
=
export_module
.
get_inference_signatures
({
functions
=
export_module
.
get_inference_signatures
({
"serve_text"
:
"serving_default"
"serve_text"
:
"serving_default"
...
@@ -371,6 +374,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -371,6 +374,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
outputs
.
dtype
,
tf
.
string
)
self
.
assertEqual
(
outputs
.
dtype
,
tf
.
string
)
tmp_dir
=
self
.
get_temp_dir
()
tmp_dir
=
self
.
get_temp_dir
()
tmp_dir
=
os
.
path
.
join
(
tmp_dir
,
"padded_decode"
,
str
(
padded_decode
))
export_base_dir
=
os
.
path
.
join
(
tmp_dir
,
"export"
)
export_base_dir
=
os
.
path
.
join
(
tmp_dir
,
"export"
)
ckpt_dir
=
os
.
path
.
join
(
tmp_dir
,
"ckpt"
)
ckpt_dir
=
os
.
path
.
join
(
tmp_dir
,
"ckpt"
)
ckpt_path
=
tf
.
train
.
Checkpoint
(
model
=
model
).
save
(
ckpt_dir
)
ckpt_path
=
tf
.
train
.
Checkpoint
(
model
=
model
).
save
(
ckpt_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment