Commit c1f02ec3 authored by Rui Qian's avatar Rui Qian Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 465127497
parent 02ac43fc
...@@ -103,8 +103,8 @@ def top_k_by_class(predictions, labels, k=20): ...@@ -103,8 +103,8 @@ def top_k_by_class(predictions, labels, k=20):
Args: Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'. are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. labels: A numpy matrix containing the ground truth labels. Dimensions are
Dimensions are 'batch' x 'num_classes'. 'batch' x 'num_classes'.
k: the top k non-zero entries to preserve in each prediction. k: the top k non-zero entries to preserve in each prediction.
Returns: Returns:
...@@ -143,9 +143,10 @@ def top_k_triplets(predictions, labels, k=20): ...@@ -143,9 +143,10 @@ def top_k_triplets(predictions, labels, k=20):
Args: Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'. are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. labels: A numpy matrix containing the ground truth labels. Dimensions are
Dimensions are 'batch' x 'num_classes'. 'batch' x 'num_classes'.
k: The number top predictions to pick. k: The number top predictions to pick.
Returns: Returns:
a sparse list of tuples in (prediction, class) format. a sparse list of tuples in (prediction, class) format.
""" """
...@@ -175,7 +176,7 @@ class EvaluationMetrics(object): ...@@ -175,7 +176,7 @@ class EvaluationMetrics(object):
self.sum_hit_at_one = 0.0 self.sum_hit_at_one = 0.0
self.sum_perr = 0.0 self.sum_perr = 0.0
self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
num_class, top_n=top_n) num_class, filter_empty_classes=False, top_n=top_n)
self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
self.top_k = top_k self.top_k = top_k
self.num_examples = 0 self.num_examples = 0
...@@ -217,9 +218,13 @@ class EvaluationMetrics(object): ...@@ -217,9 +218,13 @@ class EvaluationMetrics(object):
return {"hit_at_one": mean_hit_at_one, "perr": mean_perr} return {"hit_at_one": mean_hit_at_one, "perr": mean_perr}
def get(self): def get(self, return_per_class_ap=False):
"""Calculate the evaluation metrics for the whole epoch. """Calculate the evaluation metrics for the whole epoch.
Args:
return_per_class_ap: a bool variable to determine whether return the
detailed class-wise ap for more detailed analysis. Default is `False`.
Raises: Raises:
ValueError: If no examples were accumulated. ValueError: If no examples were accumulated.
...@@ -243,6 +248,10 @@ class EvaluationMetrics(object): ...@@ -243,6 +248,10 @@ class EvaluationMetrics(object):
"map": mean_ap, "map": mean_ap,
"gap": gap "gap": gap
} }
if return_per_class_ap:
epoch_info_dict["per_class_ap"] = aps
return epoch_info_dict return epoch_info_dict
def clear(self): def clear(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment