predict-img.py 3.79 KB
Newer Older
bailuo's avatar
init  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import shutil
import subprocess
import tempfile
import time
from typing import Optional

import cv2
import numpy as np
from cog import BasePredictor, Input, Path, BaseModel
from mmengine.visualization import Visualizer
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local directory the Sa2VA weights are extracted into and loaded from.
MODEL_CACHE = "checkpoints"

# Pre-packaged Sa2VA weight archives, keyed by model size. Switch the key used
# for MODEL_URL to serve a different variant (larger variants need more GPU memory).
_MODEL_URLS = {
    "4B": "https://weights.replicate.delivery/default/ByteDance/Sa2VA-4B/model.tar",
    "8B": "https://weights.replicate.delivery/default/ByteDance/Sa2VA-8B/model.tar",
    "26B": "https://weights.replicate.delivery/default/ByteDance/Sa2VA-26B/model.tar",
}
MODEL_URL = _MODEL_URLS["8B"]

class Output(BaseModel):
    # Result object returned by Predictor.predict.
    # Path to the saved binary mask PNG (pixels 0/255), or None when the model
    # produced no usable mask or the image could not be written.
    img: Optional[Path]
    # The model's text answer, converted with str().
    response: str

def download_weights(url, dest):
    """Fetch a weight archive from ``url`` into ``dest`` using the ``pget`` CLI.

    Runs ``pget -xf <url> <dest>``. The ``-x`` flag presumably extracts the
    tar archive into ``dest`` (callers later load ``dest`` as a model
    directory); confirm against the pget documentation. Prints the URL,
    the destination and the elapsed download time.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
        FileNotFoundError: if the ``pget`` executable is not on PATH.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments (NTP sync etc.) during a long download.
    start = time.perf_counter()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # close_fds=False: the child process inherits the parent's open file descriptors.
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.perf_counter() - start)
    
class Predictor(BasePredictor):
    """Cog predictor running a Sa2VA checkpoint for instruction-driven image segmentation."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        On first start the weight archive at ``MODEL_URL`` is downloaded into
        ``MODEL_CACHE``. The model (placed on ``cuda:0`` in eval mode) and its
        tokenizer are then loaded from that local directory only.
        """
        # NOTE(review): transformers may read this variable when it is first
        # imported, in which case setting it here has no effect.
        # local_files_only=True below is what explicitly keeps the loads offline.
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        # Download weights if they don't exist. Extract into a staging directory
        # and move it into place only after pget succeeds. Otherwise an
        # interrupted download would leave a half-populated MODEL_CACHE that the
        # existence check above would accept on every later start.
        if not os.path.exists(MODEL_CACHE):
            staging = MODEL_CACHE + ".partial"
            shutil.rmtree(staging, ignore_errors=True)  # leftovers of a failed attempt
            download_weights(MODEL_URL, staging)
            os.replace(staging, MODEL_CACHE)

        # device_map already places the weights on cuda:0; the trailing .cuda()
        # is retained and is presumably a no-op.
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            torch_dtype="auto",
            device_map="cuda:0",
            trust_remote_code=True,
            local_files_only=True,
        ).eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CACHE,
            trust_remote_code=True,
            local_files_only=True,
        )

    def predict(
        self,
        image: Path = Input(description="Input image for segmentation"),
        instruction: str = Input(description="Text instruction for the model"),
    ) -> Output:
        """Run a single prediction on the model.

        The image (converted to RGB) and ``"<image>" + instruction`` are passed
        to the model's ``predict_forward``. When the answer contains ``[SEG]``
        and a usable mask is returned, the first mask is binarised
        (threshold 0.5 -> pixel values 0/255) and written as a PNG.

        Returns:
            Output whose ``response`` is the model's text answer and whose
            ``img`` is the mask PNG path. ``img`` is None when no mask was
            requested, none was returned, the mask is not 2-D / is empty, or
            the PNG could not be written.
        """
        pil_image = Image.open(str(image)).convert('RGB')

        input_dict = {
            'image': pil_image,
            'text': f"<image>{instruction}",
            'past_text': '',
            'mask_prompts': None,
            'tokenizer': self.tokenizer,
        }
        return_dict = self.model.predict_forward(**input_dict)
        answer = return_dict["prediction"]
        response = str(answer)

        if '[SEG]' not in answer:
            return Output(img=None, response=response)

        # The answer can contain [SEG] without any mask being returned; guard
        # instead of failing with KeyError/IndexError. len() is used rather than
        # truthiness because the container may be an array.
        masks = return_dict.get("prediction_masks")
        if masks is None or len(masks) == 0:
            return Output(img=None, response=response)

        binary_mask = self._to_binary_mask(masks[0])
        if binary_mask is None:
            return Output(img=None, response=response)

        # A fresh directory per call so successive or concurrent predictions
        # never overwrite each other's output; the file name stays output.png.
        out_dir = tempfile.mkdtemp()
        output_path = os.path.join(out_dir, "output.png")
        if not cv2.imwrite(output_path, binary_mask):
            shutil.rmtree(out_dir, ignore_errors=True)
            return Output(img=None, response=response)
        return Output(img=Path(output_path), response=response)

    @staticmethod
    def _to_binary_mask(pred_mask):
        """Return ``pred_mask`` as a non-empty 2-D uint8 array of 0/255, or None if unusable."""
        if not isinstance(pred_mask, np.ndarray):
            # Presumably a torch tensor (possibly on GPU) — TODO confirm against predict_forward.
            pred_mask = pred_mask.cpu().numpy()
        binary = (pred_mask > 0.5).astype('uint8') * 255

        if binary.ndim == 3:
            # Take one 2-D slice: the leading axis when it has size 1,
            # otherwise the first entry of the last (channel) axis.
            binary = binary[0] if binary.shape[0] == 1 else binary[:, :, 0]
        if binary.ndim != 2:
            return None

        height, width = binary.shape
        if height == 0 or width == 0 or not np.any(binary):
            return None
        return binary