ModelZoo / ResNet50_tensorflow / Commits / f5014889

Commit f5014889
Authored May 18, 2020 by Ruoxin Sang; committed by A. Unique TensorFlower on May 18, 2020

Internal change

PiperOrigin-RevId: 312194218

parent 25160730
Showing 3 changed files with 9 additions and 6 deletions
official/nlp/transformer/data_pipeline.py           +8 -4
official/nlp/transformer/transformer_main.py        +0 -2
official/nlp/transformer/transformer_main_test.py   +1 -0
official/nlp/transformer/data_pipeline.py
@@ -51,7 +51,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import math
 import os

 from absl import logging
@@ -157,7 +156,7 @@ def _batch_examples(dataset, batch_size, max_length):
   # Create list of batch sizes for each bucket_id, so that
   # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
-  bucket_batch_sizes = [batch_size // x for x in buckets_max]
+  bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
   # bucket_id will be a tensor, so convert this list to a tensor as well.
   bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
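The only change here is the int() cast on batch_size. A minimal sketch of the difference it makes, assuming batch_size can arrive as a Python float (for example after per-replica scaling), which is not stated in the commit itself: floor division on a float keeps the result a float, and the later tf.constant(..., dtype=tf.int64) conversion expects integers.

# Sketch only; a float batch_size is an assumed scenario, not from the commit.
batch_size = 4096.0
buckets_max = [16, 32, 64]

without_cast = [batch_size // x for x in buckets_max]    # [256.0, 128.0, 64.0] -- floats
with_cast = [int(batch_size) // x for x in buckets_max]  # [256, 128, 64] -- ints

print(without_cast)
print(with_cast)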
@@ -270,7 +269,8 @@ def _read_and_batch_from_files(
 def _generate_synthetic_data(params):
   """Create synthetic data based on the parameter batch size."""
-  batch = length = int(math.sqrt(params["batch_size"]))
+  batch_size = int(params["batch_size"] // params["max_length"])
+  length = params["max_length"]
   dataset = model_helpers.generate_synthetic_data(
       input_shape=tf.TensorShape([length]),
       input_value=1,
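To make the change in shape derivation concrete, here is a hedged comparison of the old and new formulas; the numbers (a 4096-token batch_size and a max_length of 128) are illustrative, not taken from the commit.

import math

params = {"batch_size": 4096, "max_length": 128}

# Old derivation: a square batch, independent of max_length.
old_batch = old_length = int(math.sqrt(params["batch_size"]))   # 64 examples of length 64

# New derivation: sequences of exactly max_length, with the batch dimension
# sized so that batch * length still roughly matches the token budget.
new_batch = int(params["batch_size"] // params["max_length"])   # 32 examples
new_length = params["max_length"]                               # of length 128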
@@ -279,7 +279,11 @@ def _generate_synthetic_data(params):
       label_value=1,
       label_dtype=tf.int64,
   )
-  return dataset.batch(batch, drop_remainder=True)
+  if params["static_batch"]:
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+  else:
+    dataset = dataset.padded_batch(batch_size, ([None], [None]))
+  return dataset


 def train_input_fn(params, ctx=None):
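As a rough illustration of the two batching paths _generate_synthetic_data now takes, here is a toy tf.data sketch; it uses single tensors rather than the (input, label) pairs produced by model_helpers.generate_synthetic_data, so the padded_shapes argument differs from the ([None], [None]) used in the real code.

import tensorflow as tf

def toy_dataset(lengths):
  # Sequences of the given lengths standing in for synthetic examples.
  return tf.data.Dataset.from_generator(
      lambda: ([1] * n for n in lengths),
      output_types=tf.int64,
      output_shapes=[None])

# static_batch path: every element shares one fixed length, so batch() works.
static = toy_dataset([4, 4, 4, 4]).batch(2, drop_remainder=True)
for batch in static:
  print(batch.shape)  # (2, 4) twice

# Non-static path: padded_batch() pads each batch to its longest sequence.
padded = toy_dataset([3, 5, 2, 4]).padded_batch(2, padded_shapes=[None])
for batch in padded:
  print(batch.shape)  # (2, 5), then (2, 4)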
official/nlp/transformer/transformer_main.py
@@ -168,8 +168,6 @@ class TransformerTask(object):
           tpu_address=flags_obj.tpu or "")
     if self.use_tpu:
       params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
-      if not params["static_batch"]:
-        raise ValueError("TPU requires static batch for input data.")
     else:
       logging.info("Running transformer with num_gpus = %d", num_gpus)
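For orientation only, a heavily hedged sketch of the TPU branch this hunk touches; the resolver address is a placeholder, the strategy construction follows the TF 2.2-era public API rather than this repository's distribution_utils helper, and the point is that num_replicas still comes from the strategy while the static-batch requirement is no longer enforced here.

import tensorflow as tf

# Placeholder TPU address; the real code takes it from flags_obj.tpu.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://10.0.0.1:8470")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

params = {"static_batch": False}
params["num_replicas"] = strategy.num_replicas_in_sync
# Before this commit, static_batch=False raised ValueError at this point on TPU;
# now the input pipeline is expected to handle either batching mode.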
official/nlp/transformer/transformer_main_test.py
@@ -61,6 +61,7 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.train_steps = 2
     FLAGS.validation_steps = 1
     FLAGS.batch_size = 8
+    FLAGS.max_length = 1
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'off'
     FLAGS.dtype = 'fp32'
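A short sketch, not from the commit, of why the test now sets max_length: the reworked _generate_synthetic_data derives its batch dimension from batch_size // max_length, so with the flag values above the synthetic pipeline yields batches of 8 length-1 sequences.

params = {"batch_size": 8, "max_length": 1}

batch_size = int(params["batch_size"] // params["max_length"])  # 8
length = params["max_length"]                                   # 1
print(batch_size, length)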