Commit c1f02ec3 authored by Rui Qian's avatar Rui Qian Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465127497
parent 02ac43fc
......@@ -103,8 +103,8 @@ def top_k_by_class(predictions, labels, k=20):
Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. Dimensions are
'batch' x 'num_classes'.
k: the top k non-zero entries to preserve in each prediction.
Returns:
......@@ -143,9 +143,10 @@ def top_k_triplets(predictions, labels, k=20):
Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. Dimensions are
'batch' x 'num_classes'.
k: The number top predictions to pick.
Returns:
a sparse list of tuples in (prediction, class) format.
"""
......@@ -175,7 +176,7 @@ class EvaluationMetrics(object):
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
num_class, top_n=top_n)
num_class, filter_empty_classes=False, top_n=top_n)
self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
self.top_k = top_k
self.num_examples = 0
......@@ -217,9 +218,13 @@ class EvaluationMetrics(object):
return {"hit_at_one": mean_hit_at_one, "perr": mean_perr}
def get(self):
def get(self, return_per_class_ap=False):
"""Calculate the evaluation metrics for the whole epoch.
Args:
return_per_class_ap: a bool variable to determine whether return the
detailed class-wise ap for more detailed analysis. Default is `False`.
Raises:
ValueError: If no examples were accumulated.
......@@ -243,6 +248,10 @@ class EvaluationMetrics(object):
"map": mean_ap,
"gap": gap
}
if return_per_class_ap:
epoch_info_dict["per_class_ap"] = aps
return epoch_info_dict
def clear(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment