Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dfcca061
Commit
dfcca061
authored
Aug 21, 2019
by
Reed
Browse files
Use tf.float32 instead of "float32"
parent
acdf24c5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
1 deletion
+1
-1
official/transformer/v2/transformer.py
official/transformer/v2/transformer.py
+1
-1
No files found.
official/transformer/v2/transformer.py
View file @
dfcca061
...
@@ -50,7 +50,7 @@ def create_model(params, is_train):
...
@@ -50,7 +50,7 @@ def create_model(params, is_train):
if
params
[
"enable_metrics_in_training"
]:
if
params
[
"enable_metrics_in_training"
]:
logits
=
metrics
.
MetricLayer
(
vocab_size
)([
logits
,
targets
])
logits
=
metrics
.
MetricLayer
(
vocab_size
)([
logits
,
targets
])
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
logits
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
,
name
=
"logits"
,
dtype
=
"
float32
"
)(
logits
)
dtype
=
tf
.
float32
)(
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
model
=
tf
.
keras
.
Model
([
inputs
,
targets
],
logits
)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss
=
metrics
.
transformer_loss
(
loss
=
metrics
.
transformer_loss
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment