"official/transformer/model/beam_search.py" did not exist on "dea7ecf6492b02e2ced3fbba858942b2b43d3029"
utils.py 623 Bytes
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf


def accuracy_metrics(y_true, logits):
    """Return the batch accuracy wrapped in a single-entry metrics dict.

    Args:
        y_true: Class labels, forwarded unchanged to ``accuracy``.
        logits: Class scores, forwarded unchanged to ``accuracy``.

    Returns:
        ``{'enas_acc': <accuracy of logits against y_true>}``.
    """
    batch_acc = accuracy(y_true, logits)
    return {'enas_acc': batch_acc}

def accuracy(y_true, logits):
    """Compute top-1 classification accuracy for one batch.

    Args:
        y_true: Integer class labels of shape ``(batch_size,)`` or
            ``(batch_size, 1)``. Anything accepted by
            ``tf.convert_to_tensor`` (tensor, NumPy array, list) works.
        logits: Float class scores of shape ``(batch_size, num_of_classes)``.

    Returns:
        The fraction of samples whose arg-max logit equals the label, as a
        float scalar. Returns ``nan`` for an empty batch.

    Note:
        Calls ``.numpy()``, so it must run eagerly (not inside
        ``tf.function``).
    """
    y_true = tf.convert_to_tensor(y_true)
    # Flatten explicitly to (batch_size,). A bare tf.squeeze would also drop
    # the batch dimension when batch_size == 1.
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.math.argmax(logits, axis=1)
    y_pred = tf.cast(y_pred, y_true.dtype)
    # tf.math.equal is always element-wise. The `==` operator depends on TF's
    # global tensor-equality setting and compares identity when it is off.
    correct = tf.cast(tf.math.equal(y_pred, y_true), tf.int32)
    total = int(tf.size(y_true).numpy())
    if total == 0:
        # Accuracy of an empty batch is undefined; avoid dividing by zero.
        return float('nan')
    return tf.math.reduce_sum(correct).numpy() / total