"""Multimodal (text / image / audio) inference driver for a LongCat model.

Loads the model, tokenizer, and processor with ``trust_remote_code=True``
(the model repo supplies custom classes), builds a chat prompt, runs
autoregressive generation, then decodes whatever modalities the model
produced: text always; images and audio only when the corresponding id
tensors are present and non-empty.
"""

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

# Load model
model_name = "/home/dengjb/download/meituan-longcat/LongCat-Next/"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True, fix_mistral_regex=True
)
# Dynamic binding: the repo's remote code apparently reads this attribute
# off the model — TODO confirm against the model's custom modeling code.
model.text_tokenizer = tokenizer
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Set messages (the image path embedded in the user text is parsed by the
# custom processor — presumably; verify against the repo's processor code)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What book is this?./assets/book.png"},
]

# Apply chat-template
text_input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(f"{text_input=}")

# Preprocessing — the custom processor returns a (text, visual, audio) triple
text_inputs, visual_inputs, audio_inputs = processor(text=text_input, return_tensors="pt")
text_inputs = text_inputs.to(model.device)
if visual_inputs is not None:
    visual_inputs = visual_inputs.to(model.device)
if audio_inputs is not None:
    audio_inputs = audio_inputs.to(model.device)

# AR generation (no gradients needed for inference)
with torch.no_grad():
    outputs = model.generate(
        input_ids=text_inputs["input_ids"],
        visual_inputs=visual_inputs,
        audio_inputs=audio_inputs,
        return_dict_in_generate=True,
    )

# Text decoding — slice off the prompt tokens so only new tokens are decoded
output_input_ids = outputs.sequences
prompt_len = len(text_inputs["input_ids"][0])
text_output = tokenizer.decode(
    output_input_ids[0][prompt_len:], skip_special_tokens=True
)
print(f"{text_output=}")

# Images decoding — guard against the attribute being None when the model
# emitted no visual tokens (calling .size() on None raises AttributeError)
output_visual_ids = outputs.visual_ids
if output_visual_ids is not None and output_visual_ids.size(0) > 0:
    image_path_list = model.model.decode_visual_ids_and_save(
        output_visual_ids,
        save_prefix="./output_image",
        **model.generation_config.visual_generation_config["custom_params"],
    )
    print(f"{image_path_list=}")

# Audio decoding — same None guards as above for both audio id tensors
output_audio_text_ids = outputs.audio_text_ids
output_audio_ids = outputs.audio_ids
if output_audio_text_ids is not None and output_audio_text_ids.size(-1) > 0:
    audio_text = tokenizer.decode(output_audio_text_ids[0], skip_special_tokens=True)
    print(f"{audio_text=}")
if output_audio_ids is not None and output_audio_ids.size(0) > 0:
    audio_path_list = model.model.decode_audio_ids_and_save(
        output_audio_ids,
        save_prefix="./output_audio",
        **model.generation_config.audio_generation_config["custom_params"],
    )
    print(f"{audio_path_list=}")