import argparse import os from pathlib import Path from typing import Union # hf path import requests import torch from PIL import Image from transformers import AutoProcessor from transformers import AutoTokenizer import soundfile as sf import io import numpy as np import scipy.signal as signal from examples.mimo.model_providers.llava_avlm import model_provider_llava_avlm from megatron.core import dist_checkpointing, parallel_state, tensor_parallel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.training import print_rank_0 from examples.mimo.data.utils.calculate_audio_tokens import calculate_num_audio_tokens def init_distributed(tp_size: int = 1, pp_size: int = 1): if torch.distributed.is_initialized(): return rank = int(os.environ.get("LOCAL_RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) torch.cuda.set_device(rank % torch.cuda.device_count()) torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) parallel_state.initialize_model_parallel(tp_size, pp_size) def get_input_data( processor: AutoProcessor, image_processor: AutoProcessor, audio_processor: AutoProcessor, audio_path: str, image_path: str, prompt: str, device: Union[int, str] = 0): """ Prepare inputs for the MIMO model forward pass. """ def read_audio(audio_path): """Process audio file and return tensor.""" with open(audio_path, 'rb') as f: audio_bytes = f.read() audio_io = io.BytesIO(audio_bytes) waveform, sample_rate = sf.read(audio_io) # Resample if needed fixed_sample_rate = 16000 if sample_rate != fixed_sample_rate: num_samples = int(len(waveform) * fixed_sample_rate / sample_rate) waveform = signal.resample(waveform, num_samples) # Convert to tensor audio_tensor = torch.from_numpy(waveform).float() return audio_tensor def read_image(image_path): """Process image file and return tensor.""" with open(image_path, 'rb') as f: image_bytes = f.read() image_io = io.BytesIO(image_bytes) image = Image.open(image_io) image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1) # Convert to CxHxW format image_tensor = image_tensor.float() / 255.0 # rescale to [0,1] range return image_tensor # read audio and image audio_tensor = read_audio(audio_path) image_tensor = read_image(image_path) # set up prompt conversation = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, ], } ] prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) # process audio processed_audios = audio_processor(audio_tensor, sampling_rate=16000) processed_audios = torch.tensor(processed_audios["input_features"]) processed_audios = processed_audios.squeeze(0) # remove batch dim num_audio_tokens = calculate_num_audio_tokens(audio_tensor.unsqueeze(0), "openai/whisper-base") audios_seq_lengths = torch.tensor(num_audio_tokens) prompt = prompt.replace("