run_text_generation.py 33.4 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
"""Generate text using a vision language model."""
import json
import logging
import os
import sys
from functools import partial
from typing import List

# Add megatron to the path: append the directory two levels above this file to
# sys.path so the `megatron` imports below can be resolved when the script is
# run directly.
# NOTE(review): presumably that directory is the repository root containing
# the `megatron` package -- confirm against the repository layout.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)

import torch
import yaml
from config import EvaluationConfig
from evaluation.evaluation_datasets import get_evaluation_dataset
from model import model_provider
from multimodal_args import add_multimodal_extra_args

from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.inference.text_generation.api import generate_and_post_process
from megatron.inference.text_generation.forward_step import ForwardStep
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.engines import StaticInferenceEngine
from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.vlm_text_generation_controller import (
    VLMTextGenerationController,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import (
    VLMInferenceWrapper,
)
from megatron.training import get_args, get_model, get_tokenizer, print_rank_0, get_tensorboard_writer, is_last_rank
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron


def add_text_generation_args(parser):
    """Register the vision language model text generation options.

    The options are added to a dedicated argument group on ``parser``; the
    common multimodal arguments are then added through
    ``add_multimodal_extra_args``.

    Args:
        parser: Argument parser to extend.

    Returns:
        The parser returned by ``add_multimodal_extra_args``.
    """
    supported_tasks = [
        "captioning",
        "TextVQA",
        "VQAv2",
        "ChartQA",
        "MMMU",
        "OCRBench",
        "OCRBench_v2",
        "MathVista",
        "AI2D",
        "InfoVQA",
        "SPDocVQA",
        "RD_TableBench",
        "VideoMME",
        "PerceptionTest",
        "MotionBench",
        "PhysGameBench",
        "MVBench",
    ]

    # (flag, keyword arguments) pairs, registered in exactly this order.
    option_specs = [
        ("--temperature", dict(type=float, default=1.0, help='Sampling temperature.')),
        ("--top_p", dict(type=float, default=0.0, help='Top p sampling.')),
        ("--top_k", dict(type=int, default=0, help='Top k sampling.')),
        (
            "--out-seq-length",
            dict(type=int, default=128, help='Length of the output generated text.'),
        ),
        ("--output-path", dict(type=str, help='Output file path')),
        ('--input-image-path', dict(type=str, help="Input image directory")),
        (
            '--num-partitions',
            dict(type=int, default=0, help="Number of partitions for inputs."),
        ),
        ('--partition-id', dict(type=int, default=0, help="Partition index")),
        ("--gt-path", dict(type=str, help="Optional ground truth file")),
        (
            "--task",
            dict(type=str, choices=supported_tasks, help="Generation task to run"),
        ),
        (
            "--num-samples-per-partition",
            dict(type=int, default=0, help="Number of samples per partition"),
        ),
        ("--config-path", dict(type=str, help="Evaluation config file to use.")),
    ]

    arg_group = parser.add_argument_group(
        title='Vision language model text generation arguments'
    )
    for flag, options in option_specs:
        arg_group.add_argument(flag, **options)

    # Add common multimodal arguments needed for e.g. building the model.
    return add_multimodal_extra_args(parser)


def get_evaluation_dataloader(
    task,
    input_image_path,
    gt_path,
    img_h,
    img_w,
    use_tiling,
    max_num_tiles,
    use_thumbnail,
    num_samples_per_partition,
    num_partitions,
    partition_id,
    num_frames,
    num_workers,
    vision_model_type,
    split="validation"
):
    """Create the dataloader used to iterate over an evaluation dataset.

    All arguments except ``num_workers`` are forwarded to
    ``get_evaluation_dataset``; ``num_workers`` is handed to the DataLoader.

    Returns:
        torch.utils.data.DataLoader: Unbatched loader over the dataset, driven
        by a non-shuffling DistributedSampler configured with the data
        parallel rank and world size.
    """
    eval_dataset = get_evaluation_dataset(
        task,
        input_image_path,
        gt_path,
        img_h,
        img_w,
        use_tiling,
        max_num_tiles,
        use_thumbnail,
        num_samples_per_partition,
        num_partitions,
        partition_id,
        num_frames,
        vision_model_type,
        split=split
    )

    replica_rank = parallel_state.get_data_parallel_rank()
    replica_count = parallel_state.get_data_parallel_world_size()

    rank_sampler = torch.utils.data.DistributedSampler(
        eval_dataset, num_replicas=replica_count, rank=replica_rank, shuffle=False
    )

    # TODO: Batched inference is not supported yet, hence batch_size=None.
    return torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=None,
        sampler=rank_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )


def generate_samples(model, config: EvaluationConfig, print_output):
    """Generate text for every evaluation sample and yield the output records.

    Generation runs either through the mcore inference engine (when
    ``args.use_mcore_inference`` is set) or through
    ``generate_and_post_process`` with a ``VLMForwardStep``.

    Args:
        model: Model to generate with.
        config (EvaluationConfig): Task, sampling and data partition settings.
        print_output: If truthy, print every output record before yielding it.

    Yields:
        dict: One record per generated sentence. Records are only produced on
        the first tensor and pipeline parallel rank (see ``is_first_rank``);
        other ranks take part in generation but yield nothing.

    Raises:
        NotImplementedError: If no output name or output processing is defined
            for ``config.task``.
    """
    args = get_args()

    dataloader = get_evaluation_dataloader(
        config.task,
        config.input_image_path,
        config.gt_path,
        args.img_h,
        args.img_w,
        args.use_tiling,
        args.max_num_tiles,
        args.use_thumbnail,
        config.num_samples_per_partition,
        config.num_partitions,
        config.partition_id,
        args.num_frames,
        args.num_workers,
        args.vision_model_type,
        config.split
    )

    num_img_embeddings_per_tile = get_num_image_embeddings(
        args.img_h,
        args.img_w,
        args.patch_dim,
        args.vision_model_type,
        args.disable_vision_class_token,
        1,
        args.pixel_shuffle,
        args.use_tile_tags,
        args.max_num_tiles,
        args.tokenizer_prompt_format,
    )

    if args.use_mcore_inference:
        inference_wrapper_config = InferenceWrapperConfig(
            hidden_size=args.hidden_size,
            inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
            fp32_residual_connection=args.fp32_residual_connection,
            params_dtype=args.params_dtype,
            padded_vocab_size=args.padded_vocab_size,
        )
        inference_wrapped_model = VLMInferenceWrapper(model, inference_wrapper_config)
        tokenizer = get_tokenizer()
        controller = VLMTextGenerationController(
            inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
        )
        inference_engine = StaticInferenceEngine(
            controller, max_batch_size=1, random_seed=args.seed
        )
        sampling_params = SamplingParams(
            temperature=config.temperature,
            top_k=config.top_k,
            top_p=config.top_p,
            num_tokens_to_generate=config.out_seq_length,
        )

    for imgs, num_tiles, sample_id, question, answers, metadata in dataloader:
        imgs = imgs.to("cuda")
        num_tiles = num_tiles.to("cuda")

        conv = get_conversation(config.task, question, metadata)

        # Every rank needs the same request / forward step, so build it once
        # before branching on the rank.
        if args.use_mcore_inference:
            inference_request = VLMInferenceRequest(
                request_id=inference_engine.get_new_request_id(),
                prompt=conv,
                prompt_tokens=controller.tokenize_prompt(conv),
                sampling_params=sampling_params,
                num_img_embeddings_per_tile=num_img_embeddings_per_tile,
                imgs=imgs,
                num_tiles=num_tiles,
                decoder_seq_length=args.decoder_seq_length,
            )
        else:
            forward_step = partial(
                VLMForwardStep,
                num_img_embeddings_per_tile,
                imgs,
                num_tiles,
                args.decoder_seq_length,
            )

        if is_first_rank():
            if args.use_mcore_inference:
                results: List[InferenceRequest] = inference_engine.generate(
                    inference_requests=[inference_request]
                )

                resp_sentences = [
                    tokenizer.detokenize(result.prompt_tokens) + result.generated_text
                    for result in results
                ]
            else:
                resp_sentences, _, _, _ = generate_and_post_process(
                    model,
                    forward_step=forward_step,
                    prompts=[conv],
                    tokens_to_generate=config.out_seq_length,
                    top_k_sampling=config.top_k,
                    top_p_sampling=config.top_p,
                    add_BOS=False,
                    temperature=config.temperature,
                    random_seed=args.seed,
                    detokenize_segments=False,
                    data_parallel=True,
                )

            for generation in resp_sentences:
                if isinstance(sample_id, torch.Tensor):
                    sample_id = sample_id.item()

                output = {"sample_id": sample_id}

                # NOTE(review): "PerceptionTest" is handled by get_conversation
                # and by the ground-truth branch below, but has no output name
                # here, so it currently ends in NotImplementedError.
                output_name = ""
                if config.task == "captioning":
                    output_name = "caption"
                elif config.task in (
                    "TextVQA",
                    "VQAv2",
                    "ChartQA",
                    "OCRBench",
                    "MathVista",
                    "AI2D",
                    "RealworldQA",
                    "MotionBench",
                    "PhysGameBench",
                    "MVBench",
                    "InfoVQA",
                    "SPDocVQA",
                ):
                    output_name = "answer"
                elif config.task == "MMMU":
                    # Plain equality: `in ("MMMU")` would be a substring test
                    # against the string "MMMU", not tuple membership.
                    output_name = "text"
                elif config.task == "VideoMME":
                    # For VideoMME the question dict itself is the output
                    # record; the response is written into it below.
                    output_name = "response"
                    output = question
                elif config.task in ["OCRBench_v2", "RD_TableBench"]:
                    output_name = "predict"
                else:
                    raise NotImplementedError("no output name defined for", config.task)

                prompt, generated = get_prompt_and_generated(
                    generation, args.tokenizer_prompt_format
                )
                if config.task == "VideoMME":
                    output["questions"][0][output_name] = generated
                else:
                    output["prompt"] = prompt
                    output[output_name] = generated

                if config.task in ["captioning", "RD_TableBench"]:
                    output["ground_truth"] = answers
                elif config.task in (
                    "TextVQA",
                    "VQAv2",
                    "ChartQA",
                    "OCRBench",
                    "OCRBench_v2",
                    "MathVista",
                    "AI2D",
                    "PerceptionTest",
                    "RealworldQA",
                    "MotionBench",
                    "PhysGameBench",
                    "MVBench",
                    "InfoVQA",
                    "SPDocVQA",
                ):
                    if isinstance(answers, str):
                        answers = [answers]
                    output["gt_answer"] = answers

                    if len(metadata) > 0:
                        output.update(metadata)
                elif config.task == "MMMU":
                    output["prediction"] = generated
                    output.update(metadata)
                elif config.task == "VideoMME":
                    pass
                else:
                    raise NotImplementedError("no output processing defined for", config.task)

                if print_output:
                    print(output)

                yield output
        else:
            # Remaining ranks run the same generation call but produce no
            # output records.
            if args.use_mcore_inference:
                inference_engine.generate(
                    inference_requests=[inference_request]
                )
            else:
                generate_and_post_process(
                    model, forward_step=forward_step, detokenize_segments=False, data_parallel=True
                )


def get_evaluation_config():
    """Build the evaluation config from a YAML file or from the command line.

    When ``args.config_path`` is set, the YAML file's top-level mapping is
    passed to ``EvaluationConfig`` as keyword arguments. Otherwise the config
    is filled from the matching command-line arguments. If the resulting
    config has no output path, the ``generated`` directory is created and
    ``generated/<language_model_type>`` is used.

    Returns:
        EvaluationConfig: The evaluation config.
    """
    args = get_args()

    if not args.config_path:
        # No config file: mirror the relevant command-line arguments one-to-one.
        cli_fields = (
            "task",
            "temperature",
            "top_p",
            "top_k",
            "out_seq_length",
            "output_path",
            "input_image_path",
            "gt_path",
            "num_partitions",
            "partition_id",
            "num_samples_per_partition",
        )
        config = EvaluationConfig(**{name: getattr(args, name) for name in cli_fields})
    else:
        with open(args.config_path, "r") as f:
            yaml_settings = yaml.safe_load(f)

        config = EvaluationConfig(**yaml_settings)

    if config.output_path:
        return config

    # Default output path if not defined...
    os.makedirs("generated", exist_ok=True)
    config.output_path = "generated/" + args.language_model_type
    return config


def get_batch_evaluation_configs():
    """Load one evaluation config per dataset from a batch evaluation file.

    The YAML file at ``args.config_path`` must have a ``datasets`` mapping of
    dataset name to ``EvaluationConfig`` keyword arguments. Without a config
    path, a message is printed and the process exits with status 1.

    Returns:
        dict: Dataset name mapped to its ``EvaluationConfig``. Each config has
        ``dataset`` set to its name and, if it has no output path of its own,
        an output path derived from ``args.output_path``.
    """
    args = get_args()

    if not args.config_path:
        print("No config path provided")
        sys.exit(1)

    with open(args.config_path, "r") as f:
        dataset_entries = yaml.safe_load(f)['datasets']

    configs = {}
    for dataset_name, settings in dataset_entries.items():
        dataset_config = EvaluationConfig(**settings)
        dataset_config.dataset = dataset_name

        # Default output path if not defined... use args.output_path
        if not dataset_config.output_path:
            os.makedirs(args.output_path, exist_ok=True)
            dataset_config.output_path = (
                args.output_path + args.language_model_type + "-" + dataset_name
            )

        configs[dataset_name] = dataset_config

    return configs

def is_first_rank():
    """Tell whether this is the first pipeline stage and tensor parallel rank 0.

    Virtual pipeline stages are ignored when checking for the first stage.
    """
    on_first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
    return on_first_stage and parallel_state.get_tensor_model_parallel_rank() == 0


def get_output_path(config, dp_rank):
    """Return the generation output path for a data parallel rank.

    The path is built from ``config.output_path``, ``config.task``,
    ``dp_rank`` and ``config.partition_id``. A ``-step=<ckpt_step>`` suffix is
    added when the global args define a non-None ``ckpt_step``.

    Args:
        config: Evaluation config providing ``output_path``, ``task`` and
            ``partition_id``.
        dp_rank: Data parallel rank written into the file name.

    Returns:
        str: Path of the ``.jsonl`` output file. A path is always returned;
        previously a ``ckpt_step`` of None made the function return None.
    """
    base = f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}"

    # The checkpoint step is optional: the global args may be unavailable or
    # may not define `ckpt_step` at all.
    try:
        ckpt_step = getattr(get_args(), "ckpt_step", None)
    except Exception:
        ckpt_step = None

    if ckpt_step is not None:
        return f"{base}-step={ckpt_step}.jsonl"
    return f"{base}.jsonl"


def generate_and_write_samples(model, config, print_output=True):
    """Generate text and write each output record as one JSON line.

    Only the first tensor and pipeline parallel rank (see ``is_first_rank``)
    opens and writes the output file; every rank still iterates
    ``generate_samples`` under ``torch.no_grad()``.

    Args:
        model: Model to generate with.
        config: Evaluation config for the run.
        print_output: Forwarded to ``generate_samples``.
    """
    from contextlib import ExitStack

    dp_rank = parallel_state.get_data_parallel_rank()
    writes_output = is_first_rank()

    # ExitStack closes the output file even if generation raises part-way.
    with ExitStack() as stack:
        output_file = None
        if writes_output:
            output_path = get_output_path(config, dp_rank)
            output_file = stack.enter_context(open(output_path, "w"))
            print(f"output path: {output_file.name}")

        with torch.no_grad():
            for output in generate_samples(model, config, print_output):
                if output_file is not None:
                    output_file.write(json.dumps(output) + "\n")
                    # Flush per record so partial results survive a crash.
                    output_file.flush()

class VLMForwardStep(ForwardStep):
    """Inference forward step for a multimodal model.

    Extends ``ForwardStep`` by holding the images and tile counts of the
    current sample, passing them to the model on every forward call, and
    adjusting the receive buffer sequence length and the inference context's
    sequence length offset for image tokens.
    """

    def __init__(
        self,
        num_img_embeddings_per_tile,
        images,
        num_tiles,
        decoder_seq_length,
        model,
        inference_context,
    ):
        """Create multimodal forward step.

        Args:
            num_img_embeddings_per_tile: Number of image embeddings per tile.
            images: Images passed to the model as its first argument.
            num_tiles: Tensor of tile counts; its sum is the total tile count.
            decoder_seq_length: Upper bound used when sizing the receive
                buffer and when updating the sequence length offset.
            model: Model forwarded to ``ForwardStep.__init__``.
            inference_context: Inference context forwarded to
                ``ForwardStep.__init__``.
        """
        # Total number of image embeddings across all tiles of this sample.
        total_num_tiles = torch.sum(num_tiles).item()
        num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles

        super().__init__(model, inference_context)
        self._images = images
        self._num_tiles = num_tiles
        self._num_img_embeddings = num_img_embeddings
        self.decoder_seq_length = decoder_seq_length

        self._recv_only_vision_embeds = False
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        # Checks if the previous stage only has a vision encoder, and that the current stage has part of the LM decoder.
        # In this case, the current stage should only receive vision embeddings.
        if pp_rank > 0:
            self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder()

        # Checks if the current stage only has a vision encoder
        self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()

    def _forward(self, tokens, position_ids, attention_mask):
        """Run the model on the stored images and the given tokens.

        The ``attention_mask`` argument is not used: the model is always
        called with ``attention_mask=None``.
        """
        return self.model(
            self._images,
            tokens,
            position_ids,
            attention_mask=None,
            inference_context=self.inference_context,
            num_image_tiles=self._num_tiles,
            runtime_gather_output=True,
        )

    def __call__(self, tokens, position_ids, attention_mask):
        """Run one inference step and return the logits.

        Returns:
            The logits from the parent forward step (the first element when
            the parent returns a tuple), or None when this stage only has a
            vision encoder and ``tokens`` contains no image tokens.
        """
        # NOTE(review): accessing `self.model.module` assumes the model is
        # wrapped in an object exposing `.module` -- confirm with callers.
        num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
        num_tokens = tokens.size(1)
        recv_buffer_seq_length = None
        if num_image_tokens > 0:
            # When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length.
            # If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens.
            # Note that this will set a recv_buffer_seq_length for the encoder stage, this length is irrelevant since that recv buffer is never allocated.
            if self._recv_only_vision_embeds:
                recv_buffer_seq_length = self._num_img_embeddings
            else:
                recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length)
        elif self._recv_only_vision_embeds:
            # If this stage only receives vision embeddings and there are no image tokens we won't run the encoder and therefore shouldn't try to recv.
            recv_buffer_seq_length = 0

        # If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens
        if not (self._encoder_only and num_image_tokens == 0):
            output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length)
        else:
            output = None
        # The parent may return a tuple; only its first element is kept.
        if isinstance(output, tuple):
            logits, _ = output
        else:
            logits = output

        # On the first inference iteration, we compute image tokens.
        # On every PP stage(although inference params should only matter for decoder),
        # update the sequence length offset by the number of image tokens.
        if num_tokens > 1 and num_image_tokens > 0:
            if "image_tokens_count" not in self.inference_context.key_value_memory_dict:
                self.inference_context.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings

            # If the expanded sequence would exceed decoder_seq_length, advance
            # the offset only up to decoder_seq_length; otherwise advance it by
            # the number of extra positions the image embeddings add.
            if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length:
                self.inference_context.sequence_len_offset += self.decoder_seq_length - num_tokens
            else:
                self.inference_context.sequence_len_offset += (
                    self.inference_context.key_value_memory_dict["image_tokens_count"] - num_image_tokens
                )

        return logits


def get_conversation(task, question, metadata=None):
    """Build the system/user conversation for a task and evaluation question.

    In all cases, the tokenizer adds possible header tokens for the assistant,
    so only the system and user turns are produced here.

    Args:
        task: Name of the evaluation task.
        question: Question text; for "VideoMME" a dict whose
            ``["questions"][0]`` entry holds the question and its choices.
        metadata: Unused.

    Returns:
        list: Two message dicts with "role" and "content" keys.

    Raises:
        NotImplementedError: If there is no prompt for ``task``.
    """
    # Step 1: pick the system prompt; unknown tasks are rejected here.
    if task in (
        "captioning",
        "MMMU",
        "VideoMME",
        "MotionBench",
        "PhysGameBench",
        "MVBench",
        "PerceptionTest",
    ):
        system_prompt = "Answer the questions."
    elif task in (
        "TextVQA",
        "InfoVQA",
        "SPDocVQA",
        "VQAv2",
        "ChartQA",
        "OCRBench",
        "OCRBench_v2",
        "RD_TableBench",
        "RealworldQA",
        "AI2D",
    ):
        system_prompt = "Follow the user's instruction and answer questions."
    elif task == "MathVista":
        system_prompt = "You are math expert. Use your math knowledge to calculate the answer."
    else:
        raise NotImplementedError(f"No prompting support for task {task}")

    # Step 2: build the user turn.
    if task == "MMMU":
        user_prompt = question
    elif task == "captioning":
        user_prompt = f"{IMAGE_TOKEN}\nGive a brief description of this image in one sentence."
    elif task == "VideoMME":
        instruction = (
            "Select the best answer to the following multiple-choice "
            "question based on the video. Respond with only the letter "
            "(A, B, C, or D) of the correct option.\n"
        )
        # The question text followed by exactly four choices, one per line.
        parts = [question["questions"][0]["question"]]
        parts.extend(question["questions"][0]["choices"][i] for i in range(4))
        q = instruction + "".join(part + "\n" for part in parts)
        user_prompt = f"{IMAGE_TOKEN}\n{q}"
    elif task in ("TextVQA", "InfoVQA", "SPDocVQA"):
        user_prompt = f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word, phrase, or number."
    elif task in ("VQAv2", "ChartQA", "MVBench"):
        user_prompt = f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word or phrase."
    elif task in ("MotionBench", "PhysGameBench"):
        extra_instruction = "Respond with only the letter choice (A, B, C, or D) of the correct option.\n"
        user_prompt = f"{IMAGE_TOKEN}\n{question}\n{extra_instruction}"
    else:
        # Remaining tasks send the image token followed by the raw question.
        user_prompt = f"{IMAGE_TOKEN}\n{question}"

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


# Prompt format -> (assistant header that separates prompt from generation,
# markers that end the generated text, applied in order).
_GENERATION_MARKERS = {
    "llama3": ("<|start_header_id|>assistant<|end_header_id|>\n\n", ("<|eot_id|>",)),
    "llama3p1": ("<|start_header_id|>assistant<|end_header_id|>\n\n", ("<|eot_id|>",)),
    "mistral": ("[/INST]", ("</s>",)),
    "chatml": ("<|im_start|> assistant\n", ("<|im_end|>",)),
    "nvlm-yi-34b": ("<|im_start|>assistant\n", ("<|im_end|>",)),
    "qwen2p0": ("<|im_start|>assistant\n", ("<|im_end|>",)),
    "qwen2p5": ("<|im_start|>assistant\n", ("<|im_end|>",)),
    "nemotron5": ("<SPECIAL_14>assistant\n", ("<SPECIAL_15>",)),
    "nemotron5-aligned": ("Assistant\n", ("[PREFIX]", "\\n")),
}


def get_prompt_and_generated(prompt_and_generation, prompt_format):
    """Strip prompt and other unnecessary text from generation.

    The text is split on the assistant header of ``prompt_format``. The part
    before the first header is the prompt; the part between the first and a
    possible second header is the generation, which is cut at the format's
    end markers and stripped of surrounding whitespace.

    Args:
        prompt_and_generation (str): Prompt followed by the generated text.
        prompt_format (str): Tokenizer prompt format name.

    Returns:
        tuple: ``(prompt, generated)``.

    Raises:
        ValueError: If ``prompt_format`` is not supported. Names are matched
            exactly, so substrings of a supported name such as "nemotron" or
            "" are rejected.
        IndexError: If the assistant header does not occur in the text.
    """
    try:
        assistant_marker, end_markers = _GENERATION_MARKERS[prompt_format]
    except (KeyError, TypeError):
        raise ValueError(f"Prompt format {prompt_format} is not supported.") from None

    splitted = prompt_and_generation.split(assistant_marker)
    if len(splitted) < 2:
        raise IndexError(
            f"Assistant marker {assistant_marker!r} not found in the generation "
            f"for prompt format {prompt_format}."
        )

    prompt = splitted[0]
    generated = splitted[1]
    for end_marker in end_markers:
        generated = generated.split(end_marker)[0]

    # Remove possible garbage.
    generated = generated.strip()

    return prompt, generated



def run_eval(config, iteration=None):
    """Score the generations written for ``config.task`` and record the result.

    The evaluator for the task is imported lazily inside its branch. The score
    line(s) are appended to ``config.output_path + "-scores.txt"`` and the
    headline score dict is printed before being returned.

    Args:
        config: Evaluation config. ``task``, ``dataset`` and ``output_path`` are
            always read; ``gt_path`` is read for MMMU and captioning, and
            ``split`` for MMMU.
        iteration: Value reported as ``iteration=`` in the printed header and in
            the score file; only used for messages.

    Returns:
        dict: A single-entry mapping from metric name to its value.

    Raises:
        NotImplementedError: If no evaluation is implemented for ``config.task``.
    """
    # Run evaluation.
    print(f"====== {config.task} {config.dataset} at iteration={iteration} scores ======")

    # All branches append to the same per-output score file.
    scores_path = config.output_path + "-scores.txt"

    if config.task == "TextVQA":
        from evaluation.evaluate_textvqa import textvqa_eval
        avg_acc = textvqa_eval(config.output_path)

        score = {"TextVQA accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} TextVQA accuracy: {score}\n")

    elif config.task == "OCRBench":
        from evaluation.evaluate_ocrbench import ocrbench_eval
        log, avg_acc = ocrbench_eval(config.output_path)

        score = {"OCRBench accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} OCRBench accuracy: {score}\n")
            f.write(f"{log}\n")

    elif config.task == "MathVista":
        from evaluation.evaluate_mathvista import mathvista_eval
        avg_acc = mathvista_eval(config.output_path)

        score = {"MathVista accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} MathVista accuracy: {score}\n")

    elif config.task == "ChartQA":
        from evaluation.evaluate_chartqa import chartqa_eval
        avg_acc = chartqa_eval(config.output_path)

        score = {"ChartQA accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} ChartQA accuracy: {score}\n")

    elif config.task == "SPDocVQA":
        from evaluation.evaluate_spdocvqa import spdocvqa_eval
        avg_acc = spdocvqa_eval(config.output_path)

        score = {"SPDocVQA accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} SPDocVQA accuracy: {score}\n")

    elif config.task == "RealworldQA":
        from evaluation.evaluate_realworldqa import realworldqa_eval
        avg_acc = realworldqa_eval(config.output_path)

        score = {"RealworldQA accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} RealworldQA accuracy: {score}\n")

    elif config.task == "AI2D":
        from evaluation.evaluate_ai2d import ai2d_eval
        avg_acc = ai2d_eval(config.output_path)

        score = {f"AI2D {config.dataset} accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} at iteration={iteration} AI2D accuracy: {score}\n")

    elif config.task == "MMMU":
        from evaluation.evaluate_mmmu import convert_to_mmmu_format
        from examples.multimodal.evaluation.mmmu_utils import mmmu_main_eval
        result_file = convert_to_mmmu_format(config.output_path)
        # Read inside a context manager so the file handle is closed promptly
        # instead of being left for the garbage collector.
        with open(result_file) as result_f:
            result = json.load(result_f)
        mmmu_results = mmmu_main_eval(result, {"answer_dict": config.gt_path})
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.split} at iteration={iteration} :\n")
            # Only the entries whose key contains 'Overall' are reported.
            for cat, cat_val in mmmu_results.items():
                if 'Overall' in cat:
                    cat = cat.replace("Overall-", "")
                    print(f'{cat}: {cat_val["acc"] * 100:.2f}')
                    f.write(f'{cat}: {cat_val["acc"] * 100:.2f}\n')

        score = {"MMMU val accuracy": mmmu_results['Overall']['acc']}

    elif config.task == 'captioning':
        from evaluation.evaluate_coco import coco_captioning_eval
        cider_score = coco_captioning_eval(config.output_path, config.gt_path)
        score = {f"{config.task} {config.dataset} CIDEr": cider_score}

        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} CIDEr scores at iteration={iteration}: {cider_score}\n")

    elif config.task == 'MotionBench':
        from evaluation.evaluate_video_motionbench import motionbench_eval
        avg_acc = motionbench_eval(config.output_path)

        score = {"MotionBench accuracy": avg_acc}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} scores at iteration={iteration}: {score}\n")

    elif config.task == 'PhysGameBench':
        from evaluation.evaluate_video_phys_game_bench import phys_game_bench_eval
        avg_acc_dict = phys_game_bench_eval(config.output_path)

        # The full per-category dict goes to the file; only the total is returned.
        score = {"PhysGame Total accuracy": avg_acc_dict['Physgame-Total-Acc']}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} scores at iteration={iteration}: {avg_acc_dict}\n")

    elif config.task == "MVBench":
        from evaluation.evaluate_video_mvbench import mvbench_eval
        avg_acc_dict = mvbench_eval(config.output_path)

        # The full per-category dict goes to the file; only the total is returned.
        score = {"MVBench accuracy": avg_acc_dict['total-acc']}
        with open(scores_path, "a") as f:
            f.write(f"{config.task} {config.dataset} scores at iteration={iteration}: {avg_acc_dict}\n")

    else:
        raise NotImplementedError(f"online evaluation of {config.task} not implemented yet")

    print(score)
    return score



def eval_single_task():
    """Run text generation with a vision language model for a single task.

    Initializes Megatron, builds the model (loading a checkpoint when
    ``args.load`` is set), writes the generated samples and finally runs the
    evaluation on the last rank.

    Returns:
        ``[]`` on every rank that is not the last rank, ``None`` on the last rank.
    """
    initialize_megatron(extra_args_provider=add_text_generation_args)
    args = get_args()

    def wrapped_model_provider(pre_process, post_process, add_encoder=True, add_decoder=True):
        # Forward everything to model_provider, pinning parallel_output to False.
        return model_provider(
            pre_process,
            post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            parallel_output=False,
        )

    # Build the model without the DDP wrapper, then restore weights if requested.
    model_chunks = get_model(
        wrapped_model_provider,
        model_type=ModelType.encoder_and_decoder,
        wrap_with_ddp=False,
    )
    if args.load is not None:
        load_checkpoint(model_chunks, None, None)

    vlm = model_chunks[0]
    vlm.eval()

    eval_config = get_evaluation_config()
    generate_and_write_samples(vlm, eval_config)

    # The last rank must not start scoring before the first rank finished writing.
    torch.distributed.barrier()

    if is_last_rank():
        run_eval(eval_config)
        return None

    return []


def eval_batch_tasks():
    """Run text generation with a vision language model for a batch of tasks.

    The model is built once (loading a checkpoint when ``args.load`` is set)
    and reused for every task returned by ``get_batch_evaluation_configs()``.
    After each task the last rank scores the output and forwards the score to
    the tensorboard writer.
    """
    initialize_megatron(extra_args_provider=add_text_generation_args)
    args = get_args()

    def wrapped_model_provider(pre_process, post_process, add_encoder=True, add_decoder=True):
        # Forward everything to model_provider, pinning parallel_output to False.
        return model_provider(
            pre_process,
            post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            parallel_output=False,
        )

    # Build the model without the DDP wrapper, then restore weights if requested.
    model_chunks = get_model(
        wrapped_model_provider,
        model_type=ModelType.encoder_and_decoder,
        wrap_with_ddp=False,
    )
    if args.load is not None:
        load_checkpoint(model_chunks, None, None)

    vlm = model_chunks[0]
    vlm.eval()

    task_configs = get_batch_evaluation_configs()

    for task_name, task_config in task_configs.items():
        print(f"Running eval task {task_name}")
        generate_and_write_samples(vlm, task_config)

        # The last rank must not start scoring before the first rank finished writing.
        torch.distributed.barrier()

        if is_last_rank():
            # Score this task and log the result; only the last rank does this.
            task_score = run_eval(task_config, args.ckpt_step)
            from train import write_eval_to_tensorboard
            tb_writer = get_tensorboard_writer()
            write_eval_to_tensorboard([task_score], args.ckpt_step, tb_writer, args.ckpt_step)

        # Keep all ranks in step before moving on to the next task.
        torch.distributed.barrier()


# Script entry point: running this file directly evaluates a single task.
if __name__ == "__main__":
    eval_single_task()