Unverified Commit cb630ffa authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Update object detection pipeline to use post_process_object_detection methods(#20004)

parent 79c720c0
...@@ -93,12 +93,11 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -93,12 +93,11 @@ class ObjectDetectionPipeline(Pipeline):
def postprocess(self, model_outputs, threshold=0.9): def postprocess(self, model_outputs, threshold=0.9):
target_size = model_outputs["target_size"] target_size = model_outputs["target_size"]
raw_annotations = self.feature_extractor.post_process(model_outputs, target_size) raw_annotations = self.feature_extractor.post_process_object_detection(model_outputs, threshold, target_size)
raw_annotation = raw_annotations[0] raw_annotation = raw_annotations[0]
keep = raw_annotation["scores"] > threshold scores = raw_annotation["scores"]
scores = raw_annotation["scores"][keep] labels = raw_annotation["labels"]
labels = raw_annotation["labels"][keep] boxes = raw_annotation["boxes"]
boxes = raw_annotation["boxes"][keep]
raw_annotation["scores"] = scores.tolist() raw_annotation["scores"] = scores.tolist()
raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels] raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment