# --------------------------------------------------------
# Magma - Multimodal AI Agent at Microsoft Research
# Copyright (c) 2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------

import ast
import random
import re

import gradio as gr
import torch
import torchvision
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.visualizer import Visualizer

'''
build model
'''
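# Fix all RNG seeds so sampled generations are reproducible across runs.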
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(0)

# Side length of the spatial quantization grid: inputs are resized to this
# resolution and the model's coordinates are expressed on this grid.
spatial_quant_size = 256

# Load AI Model
dtype = torch.bfloat16
device = "cuda"
magma_model_id = "microsoft/Magma-8B"
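# Magma ships custom modeling/processing code on the Hugging Face Hub,
# hence trust_remote_code=True on both loads.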
model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype)
processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
model.to(device)

@torch.no_grad()
def inference(image, task, num_marks, speed, steps):
    task_description = task
    # Marks are numbered 1..num_marks and referenced by these ids in the prompt.
    mark_ids = [i + 1 for i in range(num_marks)]

    # Resize to the quantization grid so predicted coordinates map directly to grid cells.
    image_resized = image.resize((spatial_quant_size, spatial_quant_size))

    magma_template = (
        # "<image>\nThe image is labeled with numeric marks {}.\n"
        "<image>\nThe image is split into 256x256 grids and is labeled with numeric marks {}.\n"
        "The robot is doing: {}. To finish the task, how to move the numerical marks in the image with speed {} for the next {} steps?\n"
    )

    """
    Visual Trace Generation
    """
    if model.config.mm_use_image_start_end:
        magma_template = magma_template.replace("<image>", "<image_start><image><image_end>")    
    conv_user = magma_template.format(mark_ids, task_description, speed, steps)
    print(conv_user)

    convs = [
        {
            "role": "system",
            "content": "You are agent that can see, talk and act.",
        },
        {"role": "user", "content": conv_user},
    ]

    prompt = processor.tokenizer.apply_chat_template(
        convs,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(images=image_resized, texts=prompt, return_tensors="pt")
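    # Add the batch dimension the model expects on the image tensors.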
    inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
    inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)    
    inputs = inputs.to(dtype).to(device)

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            temperature=0.3,
            do_sample=True,
            num_beams=1,
            max_new_tokens=1024,
            use_cache=True,
        )
    response = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

    if len(response) == 0:
        return None
    # Extract traces from the response; the model may prefix the trace dict
    # with the list of selected marks.
    if "and their future positions are:" in response:
        selected_marks_str, traces_str = response.split("and their future positions are:\n")
    else:
        selected_marks_str, traces_str = None, response
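    # Format assumed from the parsing below: per-mark entries separated by
    # blank lines, e.g.
    #   1: "[(128, 64), (130, 66), ...]"
    # Joining entries with ',' and wrapping in braces yields a dict literal
    # mapping mark id -> coordinate-list string.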

    try:
        traces_dict = ast.literal_eval('{' + traces_str.strip().replace('\n\n',',') + '}')
        overlay_traces = []
        for mark_id, trace in traces_dict.items():
            # convert list of tuples to tensor
            trace = torch.tensor(ast.literal_eval(trace)).unsqueeze(1)
            overlay_traces.append(trace)
        # padded to the same length with the last element
        max_len = max([trace.shape[0] for trace in overlay_traces])
        for i in range(len(overlay_traces)):
            if overlay_traces[i].shape[0] < max_len:
                overlay_traces[i] = torch.cat([overlay_traces[i], overlay_traces[i][-1].unsqueeze(0).repeat(max_len - overlay_traces[i].shape[0], 1, 1)], dim=0)        
        overlay_traces = torch.cat(overlay_traces, dim=1).unsqueeze(0)
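        # overlay_traces: (1, num_steps, num_marks, 2), still in grid coordinates.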
        # if selected_marks_str is not None:
        #     selected_marks = re.findall(r'\[(.*?)\]', selected_marks_str)
        #     selected_marks = [torch.tensor(ast.literal_eval(mark)).unsqueeze(0) for mark in selected_marks]
        #     selected_marks = torch.cat(selected_marks, dim=0).unsqueeze(0)        
        #     overlay_traces = torch.cat([selected_marks.unsqueeze(1), overlay_traces], dim=1)
        # Map grid coordinates back to pixel coordinates in the original image.
        overlay_traces = overlay_traces.float() / spatial_quant_size
        overlay_traces[:, :, :, 0] = overlay_traces[:, :, :, 0] * image.size[0]
        overlay_traces[:, :, :, 1] = overlay_traces[:, :, :, 1] * image.size[1]
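        # Render the traces over a static "video": one copy of the input frame
        # per predicted step.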
        images = [image] * overlay_traces.shape[1]
        overlay_visibility = torch.ones_like(overlay_traces[..., 0])  # every mark visible at every step
        video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float() * 255
        vis = Visualizer(save_dir="./saved_videos", pad_value=0, linewidth=2, tracks_leave_trace=-1)
        vis.visualize(video, overlay_traces, overlay_visibility)
        # return video path
        return "./saved_videos/video.mp4"
    except Exception as e:
        print(e)
        return None

class ImageMask(gr.components.Image):
    """
    Sets: source="upload", tool="sketch"
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        return super().preprocess(x)

class Video(gr.components.Video):
    """
    Sets: source="upload"
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", **kwargs)

    def preprocess(self, x):
        return super().preprocess(x)


'''
launch app
'''
title = "Magma"

description = '''Magma: Multimodal Agent to Act

Usage instructions:
🎈 Try our default examples first (sketches are not automatically drawn on the input and example images);
🎈 The video demo takes about 30-60s to process; please refresh if you hit an error on upload;
🎈 Upload an image/video (to use a referred region of another image, check "Example" and upload the other image in the referring image panel);
🎈 Select at least one type of prompt of your choice (to use a referred region of another image, check "Example");
🎈 Remember to provide the actual prompt for each prompt type you select, otherwise you will get an error (e.g., remember to draw on the referring image);
🎈 By default our model supports the vocabulary of the 133 COCO categories; everything else will be classified as 'others' or misclassified.
'''

article = "This demo runs on Magma-8B."
inputs = [
    gr.components.Image(label="Draw on Image", type="pil"),
    gr.Textbox(label="Task"),
    gr.Slider(1, 50, value=10, label="Number of Marks", info="Choose between 1 and 50"),
    gr.Slider(2, 50, value=8, label="Speed", info="Choose between 2 and 50"),
    gr.Slider(2, 50, value=8, label="Steps", info="Choose between 2 and 50"),
]
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=[
        gr.Video(label="Robot planning trajectory", format="mp4"),
    ],
    examples=[
        ["agents/robot_traj/sample.png", "Pick up the chip bag.", 9, 8, 8],
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging='never',
    cache_examples=False,
).launch(share=True)