"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c28797e379163000b1c959fca91d57002d3174c0"
Commit 67f81649 authored by guptapriya

Remove loss layer

parent ffbada72
@@ -181,25 +181,3 @@ def transformer_loss(logits, labels, smoothing, vocab_size):
   xentropy, weights = padded_cross_entropy_loss(logits, labels, smoothing,
                                                 vocab_size)
   return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
-
-
-class LossLayer(tf.keras.layers.Layer):
-  """Custom layer of transformer loss for the Transformer model."""
-
-  def __init__(self, vocab_size, label_smoothing):
-    super(LossLayer, self).__init__()
-    self.vocab_size = vocab_size
-    self.label_smoothing = label_smoothing
-
-  def get_config(self):
-    return {
-        "vocab_size": self.vocab_size,
-        "label_smoothing": self.label_smoothing,
-    }
-
-  def call(self, inputs):
-    logits, targets = inputs[0], inputs[1]
-    loss = transformer_loss(logits, targets, self.label_smoothing,
-                            self.vocab_size)
-    self.add_loss(loss)
-    return logits
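With LossLayer removed, callers can compute the training loss by calling transformer_loss directly instead of attaching it via add_loss. Below is a minimal sketch, not part of this commit, of a custom training step under that assumption; the model, optimizer, and train_step names are hypothetical, and transformer_loss is assumed importable from this module.

import tensorflow as tf

def train_step(model, optimizer, inputs, targets, label_smoothing, vocab_size):
  """One training step that computes the Transformer loss without LossLayer."""
  with tf.GradientTape() as tape:
    # Hypothetical model: takes (inputs, targets) and returns per-token logits.
    logits = model([inputs, targets], training=True)
    # transformer_loss is the function kept by this commit; it returns the
    # label-smoothed cross-entropy averaged over non-padding target tokens.
    loss = transformer_loss(logits, targets, label_smoothing, vocab_size)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss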