longcat-next_inference.py 2.45 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor

# Load model
# NOTE(review): hard-coded local checkpoint path — point this at your own
# download location (or a hub repo id) before running.
model_name = "/home/dengjb/download/meituan-longcat/LongCat-Next/"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bf16 weights halve memory vs fp32
    device_map="auto",           # auto-place/shard across available devices
    trust_remote_code=True,      # model repo ships custom modeling code
)
model.eval()  # inference mode: disables dropout etc.

# fix_mistral_regex is a custom kwarg — presumably consumed by the repo's
# remote tokenizer code; verify against the model repo.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, fix_mistral_regex=True)
model.text_tokenizer = tokenizer # Dynamic binding — presumably the remote model code reads this attribute; confirm
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Set messages
# The user turn embeds an image by wrapping its path in the model's custom
# <longcat_img_start>...<longcat_img_end> markers — presumably resolved by
# the processor below; confirm against the model repo docs.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What book is this?<longcat_img_start>./assets/book.png<longcat_img_end>"}
]

# Apply chat-template
text_input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # keep as a plain string; tokenized by the processor later
    add_generation_prompt=True,  # append the assistant-turn prefix for generation
)
print(f"{text_input=}")

# Preprocessing
# The processor splits the prompt into per-modality inputs; the None checks
# below show that visual_inputs / audio_inputs may be absent for a prompt
# without that modality.
text_inputs, visual_inputs, audio_inputs = processor(text=text_input, return_tensors="pt")
text_inputs = text_inputs.to(model.device)
if visual_inputs is not None:
    visual_inputs = visual_inputs.to(model.device)
if audio_inputs is not None:
    audio_inputs = audio_inputs.to(model.device)

# AR: autoregressive generation. no_grad avoids building the autograd graph;
# return_dict_in_generate gives a structured output object (sequences plus
# the model-specific visual/audio id fields read below).
with torch.no_grad():
    outputs = model.generate(
        input_ids=text_inputs["input_ids"],
        visual_inputs=visual_inputs,
        audio_inputs=audio_inputs,
        return_dict_in_generate=True,
    )

# Text decoding
# outputs.sequences contains prompt + completion; slice off the prompt
# tokens (length of the input ids) and decode only the newly generated tail.
output_input_ids = outputs.sequences
text_output = tokenizer.decode(output_input_ids[0][len(text_inputs["input_ids"][0]):], skip_special_tokens=True)
print(f"{text_output=}")

# Images decoding
# Decode any generated image-token ids back into image files on disk.
output_visual_ids = outputs.visual_ids
# Defensive guard: the input side treats modality payloads as optional
# (None), so a text-only reply may plausibly leave visual_ids as None —
# check before calling .size() to avoid an AttributeError.
if output_visual_ids is not None and output_visual_ids.size(0) > 0:
    image_path_list = model.model.decode_visual_ids_and_save(
        output_visual_ids,
        save_prefix="./output_image",  # files saved with this path prefix
        **model.generation_config.visual_generation_config["custom_params"],
    )
    print(f"{image_path_list=}")

# Audio decoding
# Decode the audio-aligned text ids (transcript) and the audio-token ids
# (waveform) produced by generation, if any.
output_audio_text_ids = outputs.audio_text_ids
output_audio_ids = outputs.audio_ids
# Defensive guards mirroring the input-side None handling: a reply with no
# audio stream may leave these attributes as None, and an unguarded .size()
# call would raise AttributeError.
if output_audio_text_ids is not None and output_audio_text_ids.size(-1) > 0:
    audio_text = tokenizer.decode(output_audio_text_ids[0], skip_special_tokens=True)
    print(f"{audio_text=}")
if output_audio_ids is not None and output_audio_ids.size(0) > 0:
    audio_path_list = model.model.decode_audio_ids_and_save(
        output_audio_ids,
        save_prefix="./output_audio",  # files saved with this path prefix
        **model.generation_config.audio_generation_config["custom_params"],
    )
    print(f"{audio_path_list=}")