Unverified Commit cb630ffa authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Update object detection pipeline to use post_process_object_detection methods(#20004)

parent 79c720c0
......@@ -93,12 +93,11 @@ class ObjectDetectionPipeline(Pipeline):
def postprocess(self, model_outputs, threshold=0.9):
target_size = model_outputs["target_size"]
raw_annotations = self.feature_extractor.post_process(model_outputs, target_size)
raw_annotations = self.feature_extractor.post_process_object_detection(model_outputs, threshold, target_size)
raw_annotation = raw_annotations[0]
keep = raw_annotation["scores"] > threshold
scores = raw_annotation["scores"][keep]
labels = raw_annotation["labels"][keep]
boxes = raw_annotation["boxes"][keep]
scores = raw_annotation["scores"]
labels = raw_annotation["labels"]
boxes = raw_annotation["boxes"]
raw_annotation["scores"] = scores.tolist()
raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment