ModelZoo / ResNet50_tensorflow
Commit 0344c550 (Unverified)
Authored May 15, 2018 by Katherine Wu
Committed by GitHub, May 15, 2018

Fix transformer loss (#4270)

Parent: 461fc094

Showing 2 changed files with 6 additions and 3 deletions:
  official/transformer/transformer_main.py   +4 -1
  official/transformer/utils/metrics.py      +2 -2

official/transformer/transformer_main.py
@@ -81,9 +81,12 @@ def model_fn(features, labels, mode, params):
     logits = output
 
     # Calculate model loss.
+    # xentropy contains the cross entropy loss of every nonpadding token in the
+    # targets.
     xentropy, weights = metrics.padded_cross_entropy_loss(
         logits, targets, params.label_smoothing, params.vocab_size)
-    loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
+    # Compute the weighted mean of the cross entropy losses
+    loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
 
     # Save loss as named tensor that will be logged with the logging hook.
     tf.identity(loss, "cross_entropy")

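
The loss line in this hunk is a weighted mean: the sum of per-token cross entropy divided by the sum of the weights. If the returned xentropy is already zeroed at padding positions, as the added comment suggests, multiplying by weights a second time is unnecessary. The sketch below illustrates that computation in plain Python rather than TensorFlow. The function name `weighted_mean_loss`, the 0/1 weights, and the padding id of 0 are assumptions made for illustration; none of them is shown in the diff.

```python
def weighted_mean_loss(token_losses, labels, pad_id=0):
    """Mean cross entropy over non-padding tokens (illustrative sketch).

    token_losses and labels are equal-shaped nested lists [batch][length].
    pad_id=0 and 0/1 weights are assumptions; the diff does not show how
    the weights are actually built.
    """
    # 1.0 for real tokens, 0.0 for padding positions.
    weights = [[0.0 if tok == pad_id else 1.0 for tok in row] for row in labels]
    # Masked losses, playing the role of `xentropy` as returned in the diff.
    xentropy = [[loss * w for loss, w in zip(loss_row, w_row)]
                for loss_row, w_row in zip(token_losses, weights)]
    total = sum(sum(row) for row in xentropy)
    count = sum(sum(row) for row in weights)
    # Note: raises ZeroDivisionError if every position is padding.
    return total / count


losses = [[2.0, 4.0, 9.0], [6.0, 9.0, 9.0]]
labels = [[5, 7, 0], [3, 0, 0]]
print(weighted_mean_loss(losses, labels))  # 4.0
```

Here the three non-padding losses (2.0, 4.0, 6.0) average to 4.0, while an unmasked mean over all six positions would give 6.5. Dividing by the sum of the weights instead of the number of positions keeps padding out of the denominator as well.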
official/transformer/utils/metrics.py
@@ -58,8 +58,8 @@ def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
     smoothing: Label smoothing constant, used to determine the on and off values
     vocab_size: int size of the vocabulary
   Returns:
-    Returns a float32 tensor with shape
-      [batch_size, max(length_logits, length_labels)]
+    Returns the cross entropy loss and weight tensors: float32 tensors with
+      shape [batch_size, max(length_logits, length_labels)]
   """
   with tf.name_scope("loss", [logits, labels]):
     logits, labels = _pad_tensors_to_same_length(logits, labels)
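
The updated docstring says both returned tensors have shape [batch_size, max(length_logits, length_labels)], which follows from padding the two inputs to a common length first. The following is a minimal plain-Python sketch of that idea. `pad_to_same_length` is a hypothetical stand-in, not the body of `_pad_tensors_to_same_length` (which this diff does not show), and it assumes right-padding with 0 on 2-D nested lists.

```python
def pad_to_same_length(x, y, pad_value=0):
    """Right-pad every row of x and y to the longest row length in either.

    Hypothetical stand-in for illustration; x and y are non-empty lists of
    lists shaped [batch][length]. pad_value=0 is an assumption.
    """
    max_len = max(len(row) for row in x + y)

    def pad(rows):
        # Append pad_value until each row reaches max_len.
        return [row + [pad_value] * (max_len - len(row)) for row in rows]

    return pad(x), pad(y)


predictions = [[11, 12, 13, 14]]   # length 4
labels = [[11, 12]]                # length 2
predictions, labels = pad_to_same_length(predictions, labels)
print(labels)                      # [[11, 12, 0, 0]]
print(len(labels[0]))              # 4
```

After padding, any per-position loss and weight values computed from these inputs share the same [batch_size, max_length] shape, which is what lets the caller sum them and divide one by the other.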