# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""This script provides a basic training loop for MIMO models."""

import os
import sys
from functools import partial
from typing import Any, Dict, Iterator

import torch
from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_src_rank,
)

# Add the parent directory to the path to import from megatron
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)

from data.energon_avlm_task_encoder import llava_avlm_dataloader_provider
from data.energon_vlm_task_encoder import llava_vlm_dataloader_provider
from data.mock import (
    train_valid_test_datasets_provider as mock_train_valid_test_datasets_provider,
)
from model_providers.llava_avlm import model_provider_llava_avlm
from model_providers.llava_vlm import model_provider_llava_vlm
from model_providers.mock import model_provider_mock_vlm_single_encoder
from utils.data_helpers import broadcast_nested_data_batch

from megatron.core.enums import ModelType
from megatron.training import get_args, pretrain

# Registry mapping --model-provider choices to model-construction callables.
# The video variant reuses the llava_vlm provider with video input enabled.
_MODEL_PROVIDERS = {
    "mock": model_provider_mock_vlm_single_encoder,
    "llava_vlm": model_provider_llava_vlm,
    "video_llava_vlm": partial(model_provider_llava_vlm, is_video_input=True),
    "llava_avlm": model_provider_llava_avlm,
}

# Registry mapping --dataset-provider choices to dataloader/dataset callables.
# Keys mirror _MODEL_PROVIDERS so the two flags can be chosen consistently.
_DATASET_PROVIDERS = {
    "mock": mock_train_valid_test_datasets_provider,
    "llava_vlm": llava_vlm_dataloader_provider,
    "video_llava_vlm": partial(llava_vlm_dataloader_provider, is_video_input=True),
    "llava_avlm": llava_avlm_dataloader_provider,
}


def add_mimo_args(parser):
    """Add MIMO-specific arguments to the parser.

    Intended as Megatron's ``extra_args_provider``: receives the argparse
    parser, registers a 'MIMO' argument group, and returns the parser.
    """
    group = parser.add_argument_group('MIMO', 'MIMO specific arguments')
    # MIMO-specific parameters
    group.add_argument('--dataset-provider', type=str, default='mock',
                       help='Dataset provider to choose from [mock, llava_vlm, video_llava_vlm, llava_avlm]')
    group.add_argument('--model-provider', type=str, default='mock',
                       help='Model provider to choose from [mock, llava_vlm, video_llava_vlm, llava_avlm]')
    # mock dataloader related args
    # can control mock samples with total seq length and image seq length
    group.add_argument('--image-size', type=int, default=224, help='Image size for vision encoder')
    group.add_argument('--total-seq-length', type=int, default=512, help='Total sequence length')
    group.add_argument('--pad-token-id', type=int, default=0, help='Padding token ID')
    group.add_argument('--image-token-id', type=int, default=32000, help='Image token ID')
    group.add_argument(
        '--image-seq-length', type=int, default=197, help='Number of image tokens to pad'
    )
    group.add_argument(
        '--audio-encoder-model', type=str, default=None, help='Audio encoder model name'
    )
    group.add_argument(
        '--hf-assign-unused-tokens',
        type=str,
        nargs='+',
        default=None,
        # NOTE(review): the original help text was truncated in this source at the opening
        # quote of its example; the example below is a reconstruction — confirm against
        # the upstream file before relying on its exact wording.
        help='Assigning unused tokens to special tokens. Example: '
        '--hf-assign-unused-tokens "<image>,32000" "<video>,32001"',
    )
    # NOTE(review): the source chunk ends inside this function, so any further arguments
    # from the original are not visible here. Megatron's extra_args_provider contract
    # expects the parser back — confirm the original tail did the same.
    return parser