diff --git a/.coveragerc b/.coveragerc
old mode 100755
new mode 100644
diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
old mode 100755
new mode 100644
diff --git a/CHANGELOG.md b/CHANGELOG.md
old mode 100755
new mode 100644
diff --git a/CODEOWNERS b/CODEOWNERS
old mode 100755
new mode 100644
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
old mode 100755
new mode 100644
diff --git a/Dockerfile.ci.dev b/Dockerfile.ci.dev
old mode 100755
new mode 100644
diff --git a/Dockerfile.ci.lts b/Dockerfile.ci.lts
old mode 100755
new mode 100644
diff --git a/Dockerfile.linting b/Dockerfile.linting
old mode 100755
new mode 100644
diff --git a/GPT_pretraining.sh b/GPT_pretraining.sh
old mode 100755
new mode 100644
diff --git a/LICENSE b/LICENSE
old mode 100755
new mode 100644
diff --git a/Llama_pretraining.sh b/Llama_pretraining.sh
old mode 100755
new mode 100644
diff --git a/MANIFEST.in b/MANIFEST.in
old mode 100755
new mode 100644
diff --git a/README.md.origin b/README.md.origin
old mode 100755
new mode 100644
diff --git a/docs/llama_mistral.md b/docs/llama_mistral.md
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/context_parallel.rst b/docs/source/api-guide/context_parallel.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/datasets.rst b/docs/source/api-guide/datasets.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/dist_checkpointing.rst b/docs/source/api-guide/dist_checkpointing.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/dist_checkpointing.strategies.rst b/docs/source/api-guide/dist_checkpointing.strategies.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/dist_optimizer.md b/docs/source/api-guide/dist_optimizer.md
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/distributed.rst b/docs/source/api-guide/distributed.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/encoder_decoder_parallelism.rst b/docs/source/api-guide/encoder_decoder_parallelism.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/fusions.rst b/docs/source/api-guide/fusions.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/index.rst b/docs/source/api-guide/index.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/models.bert.rst b/docs/source/api-guide/models.bert.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/models.gpt.rst b/docs/source/api-guide/models.gpt.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/models.rst b/docs/source/api-guide/models.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/models.t5.rst b/docs/source/api-guide/models.t5.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/moe.rst b/docs/source/api-guide/moe.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/num_microbatches_calculator.rst b/docs/source/api-guide/num_microbatches_calculator.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/optimizer_param_scheduler.rst b/docs/source/api-guide/optimizer_param_scheduler.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/pipeline_parallel.rst b/docs/source/api-guide/pipeline_parallel.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/tensor_parallel.rst b/docs/source/api-guide/tensor_parallel.rst
old mode 100755
new mode 100644
diff --git a/docs/source/api-guide/transformer.rst b/docs/source/api-guide/transformer.rst
old mode 100755
new mode 100644
diff --git a/docs/source/images/context_parallel/CP_overview.png b/docs/source/images/context_parallel/CP_overview.png
old mode 100755
new mode 100644
diff --git a/docs/source/images/context_parallel/CP_results.png b/docs/source/images/context_parallel/CP_results.png
old mode 100755
new mode 100644
diff --git a/docs/source/images/distrib_optimizer/data_flow.png b/docs/source/images/distrib_optimizer/data_flow.png
old mode 100755
new mode 100644
diff --git a/docs/source/images/distrib_optimizer/sharding_scheme.png b/docs/source/images/distrib_optimizer/sharding_scheme.png
old mode 100755
new mode 100644
diff --git a/docs/source/images/moe/token_drop.png b/docs/source/images/moe/token_drop.png
old mode 100755
new mode 100644
diff --git a/docs/source/index.rst b/docs/source/index.rst
old mode 100755
new mode 100644
diff --git a/docs/source/user-guide/index.rst b/docs/source/user-guide/index.rst
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/README.md b/examples/academic_paper_scripts/detxoify_lm/README.md
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/annotations/filter-selfgeneration.py b/examples/academic_paper_scripts/detxoify_lm/annotations/filter-selfgeneration.py
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/annotations/perspective_api_annotate.py b/examples/academic_paper_scripts/detxoify_lm/annotations/perspective_api_annotate.py
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/annotations/preprocess.sh b/examples/academic_paper_scripts/detxoify_lm/annotations/preprocess.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py b/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/perspective_api.py b/examples/academic_paper_scripts/detxoify_lm/perspective_api.py
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh b/examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/README.md b/examples/academic_paper_scripts/msdp/README.md
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/data_processing.sh b/examples/academic_paper_scripts/msdp/data_processing.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/eval_knwl_generation.sh b/examples/academic_paper_scripts/msdp/eval_knwl_generation.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/eval_resp_generation.sh b/examples/academic_paper_scripts/msdp/eval_resp_generation.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/prep_resp_gen.sh b/examples/academic_paper_scripts/msdp/prep_resp_gen.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh b/examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/msdp/prompt_resp_gen.sh b/examples/academic_paper_scripts/msdp/prompt_resp_gen.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/CONFIG.sh b/examples/academic_paper_scripts/sc21/CONFIG.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/README.md b/examples/academic_paper_scripts/sc21/README.md
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/SBATCH.sh b/examples/academic_paper_scripts/sc21/SBATCH.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/SRUN.sh b/examples/academic_paper_scripts/sc21/SRUN.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_11.sh b/examples/academic_paper_scripts/sc21/run_figure_11.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_12.sh b/examples/academic_paper_scripts/sc21/run_figure_12.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_13.sh b/examples/academic_paper_scripts/sc21/run_figure_13.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_14.sh b/examples/academic_paper_scripts/sc21/run_figure_14.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_15.sh b/examples/academic_paper_scripts/sc21/run_figure_15.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_16.sh b/examples/academic_paper_scripts/sc21/run_figure_16.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_17.sh b/examples/academic_paper_scripts/sc21/run_figure_17.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_figure_18.sh b/examples/academic_paper_scripts/sc21/run_figure_18.sh
old mode 100755
new mode 100644
diff --git a/examples/academic_paper_scripts/sc21/run_table_1.sh b/examples/academic_paper_scripts/sc21/run_table_1.sh
old mode 100755
new mode 100644
diff --git a/examples/bert/README.md b/examples/bert/README.md
old mode 100755
new mode 100644
diff --git a/examples/bert/train_bert_340m_distributed.sh b/examples/bert/train_bert_340m_distributed.sh
old mode 100755
new mode 100644
diff --git a/examples/export/README.md b/examples/export/README.md
old mode 100755
new mode 100644
diff --git a/examples/export/knowledge_distillation/pretrain_gpt_modelopt.py b/examples/export/knowledge_distillation/pretrain_gpt_modelopt.py
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/README.md b/examples/export/ptq_and_trtllm_export/README.md
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/text_generation_ptq.py b/examples/export/ptq_and_trtllm_export/text_generation_ptq.py
old mode 100755
new mode 100644
diff --git a/examples/export/ptq_and_trtllm_export/trtllm_text_generation.py b/examples/export/ptq_and_trtllm_export/trtllm_text_generation.py
old mode 100755
new mode 100644
diff --git a/examples/export/trtllm_export/README.md b/examples/export/trtllm_export/README.md
old mode 100755
new mode 100644
diff --git a/examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py b/examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py
old mode 100755
new mode 100644
diff --git a/examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py b/examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py
old mode 100755
new mode 100644
diff --git a/examples/gpt3/README.md b/examples/gpt3/README.md
old mode 100755
new mode 100644
diff --git a/examples/gpt3/gpt_config.yaml b/examples/gpt3/gpt_config.yaml
old mode 100755
new mode 100644
index 443e4b79b88daf8d3c3b0ed0bc5cae04529db940..06257827fdfbd32d262d0da032930ebbaaf578aa
--- a/examples/gpt3/gpt_config.yaml
+++ b/examples/gpt3/gpt_config.yaml
@@ -63,6 +63,7 @@ language_model:
# MoE related
moe_router_load_balancing_type: "aux_loss"
moe_router_topk: 2
+  moe_router_topk_limited_devices: null # Number of EP ranks to consider for each token during routing; null means no device limitation.
moe_grouped_gemm: False
moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss.
moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss
diff --git a/examples/gpt3/train_gpt3_175b_distributed.sh b/examples/gpt3/train_gpt3_175b_distributed.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/README.md b/examples/inference/README.md
old mode 100755
new mode 100644
index bd8e738e55b60f38c94323a7adf445e3f7474a7e..b4b07cbc6ab88a1b3453bcecbb9534d8026a6f64
--- a/examples/inference/README.md
+++ b/examples/inference/README.md
@@ -1,5 +1,5 @@
### Megatron Core Inference Documentation
-This guide will walk you through how you can use megatron core for inference on your models.
+This guide provides an example of running model inference with Megatron Core.
### Contents
- [Megatron Core Inference Documentation](#megatron-core-inference-documentation)
@@ -18,21 +18,21 @@ This guide will walk you through how you can use megatron core for inference on
#### 1. Quick Start
-This will walk you through the flow of running batch inference on a GPT model trained using megatron core. The file can be found at [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py)
+This example runs batch inference on a GPT model trained using Megatron Core. The entrypoint is [gpt_batch_inference.py](./gpt/gpt_batch_inference.py).
-##### 1.1 Understanding The Code
-***STEP 1 - We initialize model parallel and other default arguments***
-We can default micro batch size to be 1, since for TP models it is not used, and for PP models it is calculated during runtime.
+##### 1.1 Code Walkthrough
+***STEP 1 - Initialize model parallel and other default arguments***
+The micro batch size is set to 1, since it is not used for tensor-parallel-only models, and for pipeline-parallel models it is calculated at runtime.
```python
initialize_megatron(
args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1}
)
```
-***STEP 2 - We load the model using the model_provider_function***
-NOTE: The model provider function in the script supports MCore and Legacy models.
+***STEP 2 - Load the model using the model_provider_function***
+NOTE: The model provider function supports both MCore and Legacy models.
```python
model = get_model(model_provider, wrap_with_ddp=False)
@@ -41,10 +41,10 @@ NOTE: The model provider function in the script supports MCore and Legacy models
```
***STEP 3 - Choose an engine***
-One of the important elements of the generate function is an inference engine. In this example we will be choosing the [megatron core engine](../../megatron/core/inference/engine/mcore_engine.py) with a [simple text generation controller](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py), the default engine. Other engines that will be supported in the future are TRTLLMEngine.
+Text generation requires an inference engine, which includes a scheduler. The default engine is the [Megatron Core engine](../../megatron/core/inference/engines/mcore_engine.py) with a simple [text generation controller](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py). TRTLLMEngine will be supported in the future.
```python
inference_wrapped_model = GPTInferenceWrapper(model, args)
- text_generation_controller = SimpleTextGenerationController(
+ text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model,
tokenizer=tokenizer
)
@@ -53,12 +53,12 @@ One of the important elements of the generate function is an inference engine. I
)
```
-***STEP 4 - Run the generate function and display results***
-We use default values for the [common inference params](../../megatron/core/inference/common_inference_params.py). Customize this if you want to change top_p, top_k, number of tokens to generate etc.
-*Note that the result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
+***STEP 4 - Run text generation***
+[SamplingParams](../../megatron/core/inference/sampling_params.py) contains suggested defaults. Customize it to change top_p, top_k, the number of tokens to generate, etc.
+*Note: The result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
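+
+For reference, the sampling configuration can be constructed explicitly before calling generate (an illustrative sketch; the field names follow the constructor call in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py)):
+```python
+from megatron.core.inference.sampling_params import SamplingParams
+
+sampling_params = SamplingParams(
+    temperature=1.0,            # softmax temperature used when sampling
+    top_k=1,                    # keep only the single most likely token (effectively greedy)
+    top_p=0.0,                  # 0.0 leaves top-p (nucleus) sampling off
+    num_tokens_to_generate=30,  # tokens generated per prompt
+)
+```
+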
```python
results: List[InferenceRequest] = inference_engine.generate(
- prompts=args.prompts, common_inference_params=common_inference_params
+ prompts=args.prompts, sampling_params=sampling_params
)
if torch.distributed.get_rank() == 0:
@@ -76,12 +76,12 @@ We use default values for the [common inference params](../../megatron/core/infe
##### 1.2 Running The Code
-An example run script is shown below. Change the tokenizer paths, inference params, and other settings for your model.
+An example run script is shown below. Set the tokenizer paths, inference params, and other settings appropriately.
-For a quick recap on inference params refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910)
+For a quick recap on sampling parameters, refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910).
```
-#In a slurm cluster (You could also use docker)
+# In a slurm cluster (You could also use docker)
ACCOUNT=
MLM_PATH=/path/to/megatron-lm
GPT_CKPT=/path/to/gpt/ckpt
@@ -133,8 +133,8 @@ NOTE: Other parameters which can be customized for inference are :-
--top_p (top_p sampling)
--num-tokens-to-generate (Number of tokens to generate for each prompt)
--inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will.')
---use-dist-ckpt (If you are using dist checkpoint format for the model)
---use-legacy-models (If you are using legacy gpt model instead of mcore gpt model)
+--use-dist-ckpt (If using dist checkpoint format for the model)
+--use-legacy-models (If using legacy gpt model instead of mcore gpt model)
```
@@ -142,16 +142,17 @@ NOTE: Other parameters which can be customized for inference are :-
-#### 2. Flow of Control In MCore Backend
-The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py).
-* We call [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function with all our input prompts.
-* The scheduler in the engine will add these prompts to the [active requests] pool (../../megatron/core/inference/inference_request.py) until we hit the max batch size, and then it will put the rest in the waiting requests pool.
-* The engine will then run until all requests (waiting + active) are completed
+#### 2. Control Flow in the MCore Backend
+An example of inference with static batching is provided in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py). The control flow is outlined below, with a simplified sketch after the list.
+* The [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function is called with the input prompts.
+* The `Scheduler` in the engine adds these prompts to the [active requests pool](../../megatron/core/inference/inference_request.py) until the max batch size is hit. Remaining requests are added to the waiting requests pool.
+* The engine runs until all requests (waiting + active) are completed.
* The active requests are passed into **generate_all_output_tokens_static_batch()** of the text generation controller.
- * This function uses the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) **prep_model_for_inference()** , and then runs an auto regressive loop
- * In the auto regressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to get the required input, passes it into the **run_one_forward_step()** method, which calls the appropriate (PP, TP) model `.forward()` methods to get the output logits
- * The output logits are synchronized across all pipeline parallel ranks
- * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the common inference parameters.
+ * This function uses the **prep_model_for_inference()** method of the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) and runs an autoregressive sampling loop
+ * In the autoregressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to slice out the input tokens and masks
+    * Input tokens and masks are passed into the **run_one_forward_step()** method, which calls the model `.forward()` method to get the output logits
+ * Output logits are synchronized across all pipeline parallel ranks
+ * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the sampling parameters.
* The sampled tokens are then appended to the input prompt tokens for the next iteration
* The **update_generation_status()** method of the text generation controller checks which prompts have finished generating or hit a stop condition
* After the inference loop, the result is detokenized and stored as an attribute of the InferenceRequest. These requests are marked as completed.
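+
+The loop can be summarized with the following pseudocode (a simplified, illustrative sketch of the static-batching path; helpers that are not named above are placeholders):
+```python
+# Inside the text generation controller (simplified pseudocode).
+def generate_all_output_tokens_static_batch(active_requests):
+    inference_wrapped_model.prep_model_for_inference(prompt_tokens)
+    while not all_requests_completed(active_requests):
+        # Slice out the input tokens and masks needed for this step.
+        batch = inference_wrapped_model.get_batch_for_context_window(...)
+        # Calls the model .forward(); logits are synchronized across pipeline-parallel ranks.
+        logits = inference_wrapped_model.run_one_forward_step(batch)
+        new_tokens = sample_from_logits(logits[:, -1, :], sampling_params, vocab_size)
+        append_to_prompt_tokens(active_requests, new_tokens)
+        update_generation_status(active_requests)
+    # Detokenize the results and mark the requests as completed.
+    return finalize(active_requests)
+```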
@@ -160,16 +161,18 @@ The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simpl
#### 3. Customizing The Inference Pipeline
-The following guide will walk you through how you can customize different parts of the inference pipeline. There are three levels at which you can customize the pipeline.
-* **Inference engine** - Highest level of customization. Currently we support the MCore Engine. Change this to add a new engine.
-* **Text generation controller** - Extend this to customize tokenization, detokenization, or implement a new sampling strategy.
+
+The inference pipeline can be customized at the following levels:
+
+* **Inference engine** - The MCore Engine is currently supported. Change this to add a new backend.
+* **Text generation controller** - The main sampling loop. This can be customized to support alternative tokenization, detokenization, or to implement a new sampling strategy.
* **Inference Wrapped Model** - Change this to support a new model.
* **Modify Inference Parameters** - Change this to update top_p, top_k, number of tokens to be generated, temperature, or other sampling parameters.
##### 3.1. Create Your Own Inference Backend
-This is the highest level of customization. The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file has a generate method that can be extended to support a new backend.
+The [abstract_engine.py](../../megatron/core/inference/engines/abstract_engine.py) file contains a `generate` method that can be extended to support a new backend.
```python
class AbstractEngine(ABC):
@@ -177,15 +180,17 @@ class AbstractEngine(ABC):
def generate(self) -> dict:
"""The abstract backend's generate function.
- To define your own backend, make sure you implement this and return the outputs as a dictionary .
-
+ To define a new backend, implement this method and return the outputs as a dictionary.
+```
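+
+A new backend would subclass `AbstractEngine` and implement `generate` (a minimal sketch under that assumption; `MyEngine` and its constructor arguments are hypothetical):
+```python
+from megatron.core.inference.engines.abstract_engine import AbstractEngine
+
+class MyEngine(AbstractEngine):
+    def __init__(self, text_generation_controller, max_batch_size: int):
+        self.controller = text_generation_controller
+        self.max_batch_size = max_batch_size
+
+    def generate(self) -> dict:
+        # Schedule requests, run the controller's sampling loop, and
+        # return the outputs as a dictionary keyed by request id.
+        results: dict = {}
+        ...
+        return results
+```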
-##### 3.2. Create Your Own Text Generation Controller
-In case you want to use the megatron core backend, but would like to overwrite the tokenization, text generation or detokenization extend the [simple_text_generation_controller.py](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py). The class has the following methods
+##### 3.2. Implement a new Sampling Loop
+
+The [TextGenerationController](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py) contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies.
+
``` python
-class SimpleTextGenerationController:
+class TextGenerationController:
def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts"""
@@ -193,12 +198,12 @@ class SimpleTextGenerationController:
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
- common_inference_params: CommonInferenceParams,
+ sampling_params: SamplingParams,
vocab_size: int,
) -> torch.Tensor:
"""Samples the logits to generate outputs
- Given the logits of the last token, this function samples it according to the parameters defined in common_inference_params and returns the samples
+ Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens.
"""
def update_generation_status(
@@ -229,12 +234,12 @@ class SimpleTextGenerationController:
##### 3.3. Support Other Models
-In order to support other models please extend the [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) file. The abstract wrapper already supports the following :
-* Forward method which automatically calls the appropriate forward method (PP or TP etc) depending on model parallel settings
-* Initalizes the model and puts it in eval mode
-* Obtains the input parameters (batch size, max seq length) and has an instance of the input
+Extend [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) to support other models. The abstract model wrapper implements:
+* A forward method which calls the model `forward` method appropriate for the model parallel settings
+* Initialization of the model, including putting it in `.eval()` mode
+* Setup of the input parameters (max batch size, max seq length)
-The main methods to change for your model might be the following:
+The following methods should be implemented:
```python
class AbstractModelInferenceWrapper:
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
@@ -247,28 +252,28 @@ class AbstractModelInferenceWrapper:
def get_batch_for_context_window(self) -> List:
"""Returns the input data for inference
- This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
+ This function gets called iteratively in the inference loop. It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
```
-Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of extending this for GPTModel.
+Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of implementing this for GPTModel.
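+
+A skeleton for a new model could look like the following (an illustrative sketch; `MyModelInferenceWrapper` is hypothetical and the method bodies are elided):
+```python
+class MyModelInferenceWrapper(AbstractModelInferenceWrapper):
+    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
+        # Put the model in eval mode and set up any prompt-dependent state.
+        ...
+
+    def get_batch_for_context_window(self) -> List:
+        # Return the tokens, position ids, and masks needed for the current inference step.
+        ...
+```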
##### 3.4. Modify Inference Parameters
-We use [common inference params](../../megatron/core/inference/common_inference_params.py) for text generation. Customize this if you want to change top_p, top_k, number of tokens to generate etc. If you want to add other attributes that you would use in the inference loop, you can do that as shown below
+We use [sampling params](../../megatron/core/inference/sampling_params.py) for text generation. Customize them to change top_p, top_k, the number of tokens to generate, etc. To add other attributes for use in the inference loop, do so as shown below:
```
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
-c = CommonInferenceParams(temperature=0.5)
+c = SamplingParams(temperature=0.5)
c.add_attributes({'min_length':4, 'eod_id':153})
```
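+
+Attributes added through `add_attributes` are set directly on the object (for example `c.min_length`), so a customized text generation controller can read them inside its sampling loop.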
#### 4. Future work
-The following are planned for the future releases .
+The following features are planned for future releases.
* Dynamic batching
* Paged Attention
* TRTLLM Engine support
-* Support for Multimodal model inference
\ No newline at end of file
+* Support for multimodal inference
\ No newline at end of file
diff --git a/examples/inference/gpt/simple_gpt_batch_inference.py b/examples/inference/gpt/gpt_batch_inference.py
old mode 100755
new mode 100644
similarity index 91%
rename from examples/inference/gpt/simple_gpt_batch_inference.py
rename to examples/inference/gpt/gpt_batch_inference.py
index 5c7ae5bd773cd41437650caa01e06664c7e506c2..050b230cef70d56203b7f9270a6166d7251f0769
--- a/examples/inference/gpt/simple_gpt_batch_inference.py
+++ b/examples/inference/gpt/gpt_batch_inference.py
@@ -6,10 +6,10 @@ import sys
from argparse import Namespace
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.inference_request import InferenceRequest
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController
+from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from megatron.core.transformer.module import MegatronModule
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
@@ -66,7 +66,7 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi
)
inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
- text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
+ text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size)
def main():
@@ -89,7 +89,7 @@ def main():
inference_engine = get_inference_engine(args, model)
- common_inference_params = CommonInferenceParams(
+ sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
@@ -97,7 +97,7 @@ def main():
num_tokens_to_generate=args.num_tokens_to_generate)
results: List[InferenceRequest] = inference_engine.generate(
- prompts=args.prompts, common_inference_params=common_inference_params
+ prompts=args.prompts, sampling_params=sampling_params
)
if torch.distributed.get_rank() == 0:
diff --git a/examples/inference/llama_mistral/huggingface_reference.py b/examples/inference/llama_mistral/huggingface_reference.py
old mode 100755
new mode 100644
diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.1.sh b/examples/inference/llama_mistral/run_text_generation_llama3.1.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.sh b/examples/inference/llama_mistral/run_text_generation_llama3.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/llama_mistral/run_text_generation_mistral.sh b/examples/inference/llama_mistral/run_text_generation_mistral.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/run_text_generation_server_345M.sh b/examples/inference/run_text_generation_server_345M.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh b/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh
old mode 100755
new mode 100644
diff --git a/examples/inference/t5/simple_t5_batch_inference.py b/examples/inference/t5/simple_t5_batch_inference.py
old mode 100755
new mode 100644
index 3f4557d3c2dac2ae1394adfae6d79899d9b0aa11..b4226d7de0f8352fd74bedf047559f0a7819ea84
--- a/examples/inference/t5/simple_t5_batch_inference.py
+++ b/examples/inference/t5/simple_t5_batch_inference.py
@@ -5,7 +5,7 @@ from argparse import Namespace
import torch
import pretrain_t5
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest
@@ -120,7 +120,7 @@ def main():
inference_engine = get_inference_engine(args, model)
- common_inference_params = CommonInferenceParams(
+ sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
@@ -138,7 +138,7 @@ def main():
prompts=args.prompts,
add_BOS=True,
encoder_prompts=args.encoder_prompts,
- common_inference_params=common_inference_params,
+ sampling_params=sampling_params,
)
if torch.distributed.get_rank() == 0:
diff --git a/examples/mamba/.gitignore b/examples/mamba/.gitignore
old mode 100755
new mode 100644
diff --git a/examples/mamba/Dockerfile b/examples/mamba/Dockerfile
old mode 100755
new mode 100644
diff --git a/examples/mamba/README.md b/examples/mamba/README.md
old mode 100755
new mode 100644
diff --git a/examples/mamba/run_text_gen_server_8b.sh b/examples/mamba/run_text_gen_server_8b.sh
old mode 100755
new mode 100644
diff --git a/examples/mamba/run_text_gen_server_8b_gpt3.sh b/examples/mamba/run_text_gen_server_8b_gpt3.sh
old mode 100755
new mode 100644
diff --git a/examples/mamba/train.sh b/examples/mamba/train.sh
old mode 100755
new mode 100644
diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md
old mode 100755
new mode 100644
diff --git a/examples/mixtral/train_mixtral_8x7b_distributed.sh b/examples/mixtral/train_mixtral_8x7b_distributed.sh
old mode 100755
new mode 100644
diff --git a/examples/multimodal/Dockerfile b/examples/multimodal/Dockerfile
old mode 100755
new mode 100644
diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md
old mode 100755
new mode 100644
index 62e47567b939865fa73346dc8e452f18f02685b4..a65839f8f15f0ada9a38bc5081e74e6251c298d6
--- a/examples/multimodal/README.md
+++ b/examples/multimodal/README.md
@@ -16,7 +16,7 @@ You can build a docker container using `examples/multimodal/Dockerfile` to run t
### Language model
-Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 (Base or Instruct) from HuggingFace and convert to mcore format with tensor parallel size 4.
+Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 from HuggingFace and convert to mcore format with tensor parallel size 4.
Please use the tokenizer from HuggingFace.
### Vision model
@@ -113,7 +113,7 @@ Run the following script:
```
examples/multimodal/text_generation_mistral_clip.sh --input-image-path /path/to/input/images --output-path /some/output/directory \
- --model-path /path/to/model.pt --tokenizer-path /path/to/tokenizer/ --gt-path /path/to/groundtruth/file --task generation-task-name
+ --model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name
```
where `--task generation-task-name` is the name of the evaluation benchmark such as `captioning` or `MMMU`.
diff --git a/examples/multimodal/assets/pretrain_curves.png b/examples/multimodal/assets/pretrain_curves.png
old mode 100755
new mode 100644
diff --git a/examples/multimodal/combine_lm_vision_checkpoints.sh b/examples/multimodal/combine_lm_vision_checkpoints.sh
old mode 100755
new mode 100644
diff --git a/examples/multimodal/combine_state_dicts.py b/examples/multimodal/combine_state_dicts.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py
old mode 100755
new mode 100644
index 343fcd589688b3e5bf1254189450e4fb06b88b6f..ee404604b650d32f4535a53dfba24498d9ab4f77
--- a/examples/multimodal/config.py
+++ b/examples/multimodal/config.py
@@ -7,34 +7,20 @@ from megatron.training.activations import fast_gelu, quick_gelu, squared_relu
def get_language_model_config(config):
- if config.language_model_type == "2b":
+ if config.language_model_type == "llama3_8b":
+ config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
- config.apply_query_key_layer_scaling = True
- config.layernorm_zero_centered_gamma = True
- config.bias_dropout_fusion = False
- config.rotary_percent = 0.5
- config.apply_rope_fusion = False
- config.attention_softmax_in_fp32 = True
- elif config.language_model_type == "8b":
- config.add_bias_linear = False
- config.bias_activation_fusion = False
- config.gated_linear_unit = False
- config.apply_query_key_layer_scaling = True
- config.layernorm_zero_centered_gamma = True
+ config.apply_query_key_layer_scaling = False
+ config.layernorm_zero_centered_gamma = (
+ False # Zero centered gamma not supported for RMSNorm
+ )
config.bias_dropout_fusion = False
- config.rotary_percent = 0.5
- config.attention_dropout = 0.0
config.apply_rope_fusion = False
- config.activation_func = squared_relu
- config.ffn_hidden_size = 16384
- config.masked_softmax_fusion = True
config.attention_softmax_in_fp32 = True
- config.num_query_groups = 32
- config.kv_channels = 128
- config.rotary_interleaved = False
- elif config.language_model_type == "llama3_8b":
+ config.ffn_hidden_size = 14336
+ elif config.language_model_type == "mistral_7b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
@@ -47,7 +33,7 @@ def get_language_model_config(config):
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
- elif config.language_model_type == "mistral_7b":
+ elif config.language_model_type == "yi-34b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
@@ -59,10 +45,11 @@ def get_language_model_config(config):
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
- config.ffn_hidden_size = 14336
- elif config.language_model_type == "yi-34b":
+ config.ffn_hidden_size = 20480
+ elif config.language_model_type == "qwen2.5_7B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
+ config.add_qkv_bias = True
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
@@ -72,7 +59,7 @@ def get_language_model_config(config):
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
- config.ffn_hidden_size = 20480
+ config.ffn_hidden_size = 18944
elif config.language_model_type == "qwen2.0_72B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
@@ -168,13 +155,7 @@ def get_vision_projection_config(config, hidden_size):
config.bias_activation_fusion = False
config.add_bias_linear = False
config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model.
- if config.language_model_type == "2b":
- config.ffn_hidden_size = 5440
- config.activation_func = torch.nn.functional.gelu
- if config.language_model_type == "8b":
- config.ffn_hidden_size = 16384
- config.activation_func = squared_relu
- elif config.language_model_type == "llama3_8b":
+ if config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "mistral_7b":
@@ -185,6 +166,9 @@ def get_vision_projection_config(config, hidden_size):
config.ffn_hidden_size = 20480
config.normalization = "LayerNorm"
config.activation_func = torch.nn.functional.gelu
+ elif config.language_model_type == "qwen2.5_7B":
+ config.ffn_hidden_size = 3584
+ config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "qwen2.0_72B":
config.ffn_hidden_size = 29568
config.normalization = "LayerNorm"
diff --git a/examples/multimodal/convert_llava_pretrain_to_wds.py b/examples/multimodal/convert_llava_pretrain_to_wds.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/dataloader_provider.py b/examples/multimodal/dataloader_provider.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/dataset_helpers.py b/examples/multimodal/dataset_helpers.py
old mode 100755
new mode 100644
index de76f8e45e3a32e3e2a429128ee484d4185e39f9..ecbbc502c08bcda12d52c74eaabdbd3ffc3d774b
--- a/examples/multimodal/dataset_helpers.py
+++ b/examples/multimodal/dataset_helpers.py
@@ -2,16 +2,19 @@
import bisect
import dataclasses
import json
+import re
import sys
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from image_processing import get_visual_transform
+from PIL import Image
+from torchvision.transforms import ToPILImage
import numpy as np
import torch
-from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN
+from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.energon import (
Batch,
@@ -175,6 +178,10 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
self.img_h, self.img_w = self.args.img_h, self.args.img_w
+ # This map is used to reduce the number of tiles used per image if the number of tokens is
+ # larger than the decoder_seq_length.
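+        # E.g., the maximum tiles per image is stepped down 12 -> 8 -> 6 -> 4 -> 2 -> 1 until the sample fits.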
+ self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}
+
def _get_total_seq_length(self, input_ids, num_tiles):
"""Calculate expected sequence length given text tokens length and number of tiles."""
total_num_images = len(num_tiles)
@@ -237,7 +244,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
-            cur_prompt = "<image>\n" + cur_prompt + "\n"
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n"
caption = sample.caption.strip()
@@ -282,7 +289,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
# LLAVA training: override text-prompt with just the image.
conv = [
# Note: no system message.
-            {"role": "user", "content": "<image>\n"},
+ {"role": "user", "content": IMAGE_TOKEN + "\n"},
{"role": "assistant", "content": sample.answers},
]
@@ -307,66 +314,130 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
"""Encode SFT sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
- has_image = sample.__subflavors__['has_image'] if 'has_image' in sample.__subflavors__ else False
- has_image = has_image or (hasattr(sample, "images") and len(sample.images) > 0)
- if has_video:
- # Grab the selected frames of the video as a tensor with shape
- # fhwc: (num_frames, height, width, num_channels).
- video_fhwc = sample.images[0].permute(0, 2, 3, 1)
- selected_frames = torch.linspace(
- 0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
- video_frame_fhwc = video_fhwc[selected_frames]
- imgs = []
- for video_frame_hwc in video_frame_fhwc:
- imgs += get_visual_transform(
- video_frame_hwc, self.img_h, self.img_w,
- self.args.use_tiling, self.args.max_num_tiles,
- self.args.use_thumbnail, augment, self.args.vision_model_type)
- num_tiles = [len(imgs)]
- elif has_image:
- imgs = get_visual_transform(
- sample.images[0], self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
- self.args.vision_model_type,
- )
- num_tiles = [len(imgs)]
- else:
- imgs = num_tiles = []
- sample.__key__ = "{}-{}".format("no-image", sample.__key__)
+ has_image = False
+ if hasattr(sample, "images"):
+ # If this is a text-only sample and we are freezing the LM,
+ # then use a dummy input image.
+ if len(sample.images) == 0 and self.args.freeze_LM:
+ empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
+ sample.images.append(empty_img)
+ if len(sample.images) > 0 and not has_video:
+ has_image = True
- conversation = []
# Note: Some tokenizers may ignore the system prompt.
- conversation.append({"role": "system", "content": "Answer the questions."})
-
- has_image_token = False
-
+ conversation = [{"role": "system", "content": "Answer the questions."}]
+ # Format the conversation as a list of "user" / "assistant" turns.
for text in sample.texts:
- if IMAGE_TOKEN in text["value"]:
- has_image_token = True
-
- if text["from"] == "human":
- role = "user"
- elif text["from"] == "gpt":
- role = "assistant"
- else:
- raise RuntimeError(f"unexpected role {text['from']} in {sample.texts}")
-
- turn = {"role": role, "content": text["value"]}
- conversation.append(turn)
-
- # If the sample contains an image but none of the user messages has an image token,
- # then add it to the first user message.
- if len(imgs) > 0 and not has_image_token:
+ error_msg = f"unexpected role {text['from']} in {sample.texts}"
+ assert text["from"] in ["human", "gpt"], error_msg
+ conversation.append({
+ "role": "user" if text["from"] == "human" else "assistant",
+ "content": text["value"]})
+
+ # Replace the image tags with IMAGE_TOKEN and count the number of image tags
+ number_image_tags = 0
+ image_tag_ids_list = []
+ for turn in conversation:
+ if turn["role"] == "user":
+                image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
+ image_tag_ids_list.extend(image_tag_ids)
+                turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
+ number_image_tags += turn["content"].count(IMAGE_TOKEN)
+ # For videos, we replace the image tag with the video tag
+ if has_video:
+ turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN)
+
+ # We re-order the images in sample.images according to how they appear in the conversation.
+ if len(image_tag_ids_list) > 0:
+ sample.images = [sample.images[idx] for idx in image_tag_ids_list]
+
+ # If there is only one image, but several image tags, we assume all the tags refer to the
+ # same image and duplicate the image:
+ if len(sample.images) == 1 and number_image_tags > 1:
+ sample.images = sample.images * number_image_tags
+
+ number_of_images = len(sample.images)
+        # Fail if there are more image or video tags than images or videos:
+ error_msg = (
+ f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
+ assert number_image_tags <= number_of_images, error_msg
+
+        # If there are fewer image or video tags than images or videos, prepend the tags to the first
+ # user message:
+ if number_image_tags < number_of_images:
for turn in conversation:
if turn["role"] == "user":
- turn["content"] = f"{IMAGE_TOKEN}\n" + turn["content"]
+ tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
+ turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"]
break
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
+ if has_image:
+ imgs = []
+ num_tiles = []
+ max_num_tiles = self.args.max_num_tiles
+ # We keep a buffer of 4 tokens for the question,
+ # the rest can be used for image tokens.
+ max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4
+ # We start by extracting as many tiles per image as possible, and decrease the max
+ # number of tiles if there are too many image tokens.
+ while True:
+ imgs = []
+ num_tiles = []
+ for img in sample.images:
+ img_tiles = get_visual_transform(
+ img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
+ self.args.use_thumbnail, augment, self.args.vision_model_type)
+ imgs += img_tiles
+ num_tiles += [len(img_tiles)]
+ if max_num_tiles == 1:
+ break
+ if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed:
+ if max_num_tiles in self.num_tiles_degradation_map:
+ max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
+ else:
+                    raise RuntimeError(
+                        f"Tried to decrease the number of tiles {max_num_tiles} but it's not "
+                        f"defined in the degradation map {self.num_tiles_degradation_map}")
+ else:
+ break
+ elif has_video:
+ # We don't use tiling for videos to limit the number of tokens.
+ use_tiling=False
+ # Grab the selected frames of the video as a tensor with shape
+            # fchw: (num_frames, num_channels, height, width).
+ video_fchw = sample.images[0].permute(0, 1, 2, 3)
+ selected_frames = torch.linspace(
+ 0, video_fchw.shape[0] - 1, self.args.num_frames).long()
+ video_fchw = video_fchw[selected_frames]
+ imgs = []
+ for video_chw in video_fchw:
+ to_pil = ToPILImage()
+ video_chw = to_pil(video_chw)
+ imgs += get_visual_transform(
+ video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
+ self.args.use_thumbnail, augment, self.args.vision_model_type)
+ num_tiles = [len(imgs)]
+ else:
+ imgs = num_tiles = []
+
if self.is_packing_enabled:
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
+ # Some final checks with respect to the number of image tokens and images on the tokenized
+ # conversation. There can still be errors, for instance if a non-video sample happens to
+ # have our pre-defined video token, or if the packing truncation removed a necessary image
+ # tag.
+ number_image_token = np.sum(input_ids == self.img_token_id)
+ error_msg = (
+            f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} images in {conversation}.")
+ assert number_image_token == len(num_tiles), error_msg
+ error_msg = (
+ f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
+ assert np.sum(num_tiles) == len(imgs), error_msg
+
return ImageTaskSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
@@ -407,8 +478,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
if isinstance(sample, MultiChoiceVQASample):
cur_prompt = format_multichoice_question(sample.context, sample.choices)
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+ if IMAGE_TOKEN not in cur_prompt:
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = format_multichoice_answer(sample.correct_choice_idx)
elif isinstance(sample, VQASample):
if 'docvqa' in sample.__key__:
@@ -423,8 +494,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
cur_prompt = cur_prompt.format(sample.context)
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+ if IMAGE_TOKEN not in cur_prompt:
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
if isinstance(sample.answers, list):
answer_list = sample.answers
@@ -505,11 +576,11 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_list = self.manual_prompts["DocPretraining"]["raw"]
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+ if IMAGE_TOKEN not in cur_prompt:
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
-        # Make sure there is no extra <image> tag.
-        sample.text = sample.text.replace("<image>", "")
+ # Make sure there is no extra IMAGE_TOKEN tag.
+ sample.text = sample.text.replace(IMAGE_TOKEN, "")
caption = sample.text.strip()
@@ -526,8 +597,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
ref = sample.text
region = sample.words_boxes
-        # Make sure there is no extra <image> tag
-        ref = ref.replace("<image>", "")
+ # Make sure there is no extra IMAGE_TOKEN tag
+ ref = ref.replace(IMAGE_TOKEN, "")
if len(region) == 4:
region = f"({region[0]},{region[1]}),({region[2]},{region[3]})"
@@ -550,8 +621,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
cur_prompt = cur_prompt.format(prompt_content)
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+ if IMAGE_TOKEN not in cur_prompt:
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
return sample, cur_prompt, answer
@@ -559,8 +630,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
"""Format bbox coordinates as text."""
assert len(bbox) == 4 or len(bbox) == 8
-        # Make sure there is no extra <image> tag
-        text = text.replace("<image>", "")
+ # Make sure there is no extra IMAGE_TOKEN tag
+ text = text.replace(IMAGE_TOKEN, "")
if len(bbox) == 4:
label_str = f"[{text}]({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})"
@@ -582,8 +653,8 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
-        if "<image>" not in cur_prompt:
-            cur_prompt = "<image>\n" + cur_prompt
+ if IMAGE_TOKEN not in cur_prompt:
+ cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = answer
return sample, cur_prompt, cur_answer
diff --git a/examples/multimodal/evaluate_ai2d.py b/examples/multimodal/evaluation/evaluate_ai2d.py
old mode 100755
new mode 100644
similarity index 72%
rename from examples/multimodal/evaluate_ai2d.py
rename to examples/multimodal/evaluation/evaluate_ai2d.py
index 2d5db67b67d076e6d43a815997175325d5bb25ea..39b866ae4a030c2911a197fef6a1be0e19b0cfc4
--- a/examples/multimodal/evaluate_ai2d.py
+++ b/examples/multimodal/evaluation/evaluate_ai2d.py
@@ -9,19 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- results.append(
- {
- "question_id": res["sample_id"],
- "answer": res["answer"],
- "gt_answer": res["gt_answer"],
- }
- )
+ sample_id = res["sample_id"]
+
+ # Ignore possible duplicates.
+ if sample_id in results:
+ continue
+
+ results[sample_id] = {
+ "question_id": sample_id,
+ "answer": res["answer"],
+ "gt_answer": res["gt_answer"],
+ }
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
diff --git a/examples/multimodal/evaluate_chartqa.py b/examples/multimodal/evaluation/evaluate_chartqa.py
old mode 100755
new mode 100644
similarity index 77%
rename from examples/multimodal/evaluate_chartqa.py
rename to examples/multimodal/evaluation/evaluate_chartqa.py
index e9238069d463a038c0e1b52e571e930c01b24b6a..53d4944f46e364b4cb68f8ef22dabccbf66ef3ca
--- a/examples/multimodal/evaluate_chartqa.py
+++ b/examples/multimodal/evaluation/evaluate_chartqa.py
@@ -9,15 +9,22 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- res["question_id"] = res["sample_id"]
+ sample_id = res["sample_id"]
- results.append(res)
+ # Ignore possible duplicates.
+ if sample_id in results:
+ continue
+
+ res["question_id"] = sample_id
+ results[sample_id] = res
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
diff --git a/examples/multimodal/evaluate_coco.py b/examples/multimodal/evaluation/evaluate_coco.py
old mode 100755
new mode 100644
similarity index 77%
rename from examples/multimodal/evaluate_coco.py
rename to examples/multimodal/evaluation/evaluate_coco.py
index a717090c9274781f7aabd0f5cfbc3b8b032fc689..8eeb367e8f3bb0c38bd3b0f44b8f54f0c7d32636
--- a/examples/multimodal/evaluate_coco.py
+++ b/examples/multimodal/evaluation/evaluate_coco.py
@@ -11,20 +11,28 @@ def convert_to_coco_format(input_path):
"""Convert input files to COCO compatible format."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning")
- captions = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
+ sample_id = res["sample_id"]
- question_id = res['sample_id']
- caption = res['caption'].rstrip('.').lower()
+ # Ignore possible duplicates.
+ if sample_id in results:
+ continue
- captions.append({"image_id": question_id, "caption": caption})
+ caption = res["caption"].rstrip(".").lower()
+ results[sample_id] = {
+ "image_id": sample_id,
+ "caption": caption,
+ }
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
- json.dump(captions, output_file, indent=4)
+ json.dump(results, output_file, indent=4)
return output_file_path
diff --git a/examples/multimodal/evaluate_mathvista.py b/examples/multimodal/evaluation/evaluate_mathvista.py
old mode 100755
new mode 100644
similarity index 92%
rename from examples/multimodal/evaluate_mathvista.py
rename to examples/multimodal/evaluation/evaluate_mathvista.py
index 3474c5f25e9e750ba4f77238b82ef8aaa4d7193b..a55f312f21986fb46644eb4e36979c342a2b7411
--- a/examples/multimodal/evaluate_mathvista.py
+++ b/examples/multimodal/evaluation/evaluate_mathvista.py
@@ -11,13 +11,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- results.append(res)
+ sample_id = res["sample_id"]
+
+ # Remove possible duplicates.
+ if sample_id in results:
+ continue
+
+ results[sample_id] = res
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
diff --git a/examples/multimodal/evaluate_mmmu.py b/examples/multimodal/evaluation/evaluate_mmmu.py
old mode 100755
new mode 100644
similarity index 91%
rename from examples/multimodal/evaluate_mmmu.py
rename to examples/multimodal/evaluation/evaluate_mmmu.py
index 66118fa905d69df3a1d2a07e9baa6236dd11d823..798c42bfa76009653927aa4f1339411807fb905f
--- a/examples/multimodal/evaluate_mmmu.py
+++ b/examples/multimodal/evaluation/evaluate_mmmu.py
@@ -2,9 +2,15 @@ import argparse
import glob
import json
import os
+import sys
import re
import subprocess
+# Get the absolute path of the parent directory
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
+# Add the parent directory to sys.path
+sys.path.insert(0, parent_dir)
+
from run_text_generation import get_output_path
from config import EvaluationConfig
@@ -48,6 +54,10 @@ def convert_to_mmmu_format(input_path):
)
# MMMU eval script expects just a sample_id to prediction mapping.
+ # Skip possible duplicates.
+ if sample_id in output:
+ continue
+
output[sample_id] = prediction
with open(output_file_path, "w") as output_file:
diff --git a/examples/multimodal/evaluate_ocrbench.py b/examples/multimodal/evaluation/evaluate_ocrbench.py
old mode 100755
new mode 100644
similarity index 95%
rename from examples/multimodal/evaluate_ocrbench.py
rename to examples/multimodal/evaluation/evaluate_ocrbench.py
index bc2b901065f53255a0cf4cabaa4893122d579566..b37473a67dbaeef121e734340a6161358ac0203b
--- a/examples/multimodal/evaluate_ocrbench.py
+++ b/examples/multimodal/evaluation/evaluate_ocrbench.py
@@ -8,13 +8,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- results.append(res)
+ sample_id = res["sample_id"]
+
+ # Remove possible duplicates.
+ if sample_id in results:
+ continue
+
+ results[sample_id] = res
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
diff --git a/examples/multimodal/evaluate_textvqa.py b/examples/multimodal/evaluation/evaluate_textvqa.py
old mode 100755
new mode 100644
similarity index 72%
rename from examples/multimodal/evaluate_textvqa.py
rename to examples/multimodal/evaluation/evaluate_textvqa.py
index c9bba7134ba9f7e3a925dbcc529ec97da60fac92..af782bdf0318b664e37d9a106e36e66e5f5ad63c
--- a/examples/multimodal/evaluate_textvqa.py
+++ b/examples/multimodal/evaluation/evaluate_textvqa.py
@@ -9,22 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- results.append(
- {
- "question_id": res["sample_id"],
- "answer": res["answer"],
- "gt_answer": res["gt_answer"],
- }
- )
-
- # Make order deterministic.
- # results = sorted(results, key=lambda d: d["question_id"])
+ sample_id = res["sample_id"]
+
+ # Remove possible duplicates.
+ if sample_id in results:
+ continue
+
+ results[sample_id] = {
+ "question_id": sample_id,
+ "answer": res["answer"],
+ "gt_answer": res["gt_answer"],
+ }
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
diff --git a/examples/multimodal/evaluate_vqav2.py b/examples/multimodal/evaluation/evaluate_vqav2.py
old mode 100755
new mode 100644
similarity index 88%
rename from examples/multimodal/evaluate_vqav2.py
rename to examples/multimodal/evaluation/evaluate_vqav2.py
index 0b1b9209bef3bfb5bd644ed28d5464c951965654..7807d80723f5aa67c7fcadd695e78643fd52cb6d
--- a/examples/multimodal/evaluate_vqav2.py
+++ b/examples/multimodal/evaluation/evaluate_vqav2.py
@@ -9,15 +9,22 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2")
- results = []
+ results = dict()
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
- res["question_id"] = res["sample_id"]
+ sample_id = res["sample_id"]
- results.append(res)
+ # Skip possible duplicates.
+ if sample_id in results:
+ continue
+
+ res["question_id"] = sample_id
+ results[sample_id] = res
+
+ results = list(results.values())
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
@@ -57,6 +64,9 @@ def compute_vqa_accuracy(result_file, task):
assert len(gt) == 1, "expected exactly one groundtruth answer."
gt = gt[0]
+ pred = pred.rstrip("%")
+ gt = gt.rstrip("%")
+
if is_number(pred) and is_number(gt):
pred = float(pred)
gt = float(gt)
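
Stripping a trailing `%` lets a prediction such as `"50%"` match a ground truth of `"50"` numerically instead of failing the string comparison. A small illustration, assuming an `is_number` helper equivalent to the one used by the script:

```
def is_number(s: str) -> bool:
    try:
        float(s)
        return True
    except ValueError:
        return False

pred, gt = "50%", "50"
pred, gt = pred.rstrip("%"), gt.rstrip("%")
if is_number(pred) and is_number(gt):
    # Both values become 50.0, so the prediction now counts as correct.
    print(float(pred) == float(gt))  # True
```
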
diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation/evaluation_datasets.py
old mode 100755
new mode 100644
similarity index 88%
rename from examples/multimodal/evaluation_datasets.py
rename to examples/multimodal/evaluation/evaluation_datasets.py
index 97f9ba926f1435960444626c3af41496d1bea837..50a50d56871bddd9de59c3b1444186c749892db8
--- a/examples/multimodal/evaluation_datasets.py
+++ b/examples/multimodal/evaluation/evaluation_datasets.py
@@ -188,7 +188,7 @@ class MMMUDataset(torch.utils.data.Dataset):
use_tiling,
max_num_tiles,
use_thumbnail,
- single_image,
+ prompt_style,
vision_model_type,
):
import datasets
@@ -246,7 +246,7 @@ class MMMUDataset(torch.utils.data.Dataset):
self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
- self._single_image = single_image
+ self._prompt_style = prompt_style
self._vision_model_type = vision_model_type
def __len__(self):
@@ -258,7 +258,7 @@ class MMMUDataset(torch.utils.data.Dataset):
sample = self._dataset[idx]
# Use the single image approach from the MMMU repo.
- if self._single_image:
+ if self._prompt_style == "single_image":
sample = process_single_sample(sample)
sample = construct_prompt(sample, self._config)
@@ -274,7 +274,69 @@ class MMMUDataset(torch.utils.data.Dataset):
vision_model_type=self._vision_model_type,
)
sample_num_tiles = [len(sample_imgs)]
- else:
+
+ prompt = sample["final_input_prompt"]
+ for i in range(8):
+                prompt = prompt.replace(f"<image {i+1}>", "")
+            sample["final_input_prompt"] = f"<image>\n{prompt}"
+ elif self._prompt_style == "vlmevalkit":
+ sample = construct_prompt(sample, self._config)
+
+ if sample["question_type"] == "multiple-choice":
+ question = sample["question"]
+
+ options = ""
+ for k, v in sample["index2ans"].items():
+ options += f"{k}. {v}\n"
+
+ final_prompt = f"{question}\n"
+ if "hint" in sample:
+ final_prompt += f"Hint: {sample['hint']}\n"
+
+ if "task_instructions" in sample:
+ final_prompt += f"Task instructions: {sample['task_instructions']}\n"
+
+ final_prompt += options
+ final_prompt += "Answer with the option's letter from the given choices directly."
+
+ sample["final_input_prompt"] = final_prompt.rstrip()
+ else:
+ question = sample["question"]
+ final_prompt = f"{question}\n"
+ final_prompt += "Answer the question directly."
+ sample["final_input_prompt"] = final_prompt.rstrip()
+
+ sample_imgs = []
+ sample_num_tiles = []
+
+            img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"]))))
+
+            # Divide the tile budget across the images referenced in the prompt.
+            adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
+
+            for img_idx in img_indices:
+                img_key = f"image_{img_idx}"
+                img_str = f"<image {img_idx}>"
+
+                img = sample[img_key]
+                assert img is not None, f"{img_str} is in prompt but not in sample images"
+
+ imgs = get_visual_transform(
+ img,
+ self._img_h,
+ self._img_w,
+ self._use_tiling,
+ adjusted_max_num_tiles,
+ self._use_thumbnail,
+ augment=False,
+ vision_model_type=self._vision_model_type,
+ ) # List of tiles.
+
+ sample_imgs.extend(imgs)
+ sample_num_tiles.append(len(imgs))
+
+            sample["final_input_prompt"] = " ".join([f'<image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"]
+ elif self._prompt_style == "multi_image":
sample = construct_prompt(sample, self._config)
sample_imgs = []
@@ -315,6 +377,8 @@ class MMMUDataset(torch.utils.data.Dataset):
assert (
f"" not in sample["final_input_prompt"]
), "prompt contains unhandled image tags"
+ else:
+ raise ValueError(f"unknown prompt style {self._prompt_style}")
# MMMU specific metadata.
metadata = {"question_type": sample["question_type"]}
@@ -323,10 +387,6 @@ class MMMUDataset(torch.utils.data.Dataset):
metadata["all_choices"] = sample["all_choices"]
prompt = sample['final_input_prompt']
- if self._single_image:
- for i in range(8):
-                prompt = prompt.replace(f"<image {i+1}>", "")
-            prompt = f"<image>\n{prompt}"
tile_count = torch.tensor(sample_num_tiles, dtype=torch.int)
@@ -780,8 +840,10 @@ def get_evaluation_dataset(
vision_model_type,
)
elif task == 'MMMU':
- # Note: single_image=True uses only one image like in the MMMU repo example.
- # single_image=False uses all images in the sample.
+ # Note:
+ # - prompt_style="single_image" uses only one image like in the MMMU repo example.
+ # - prompt_style="multi_image" uses multiple input images.
+ # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499
dataset = MMMUDataset(
input_image_path,
num_samples_per_partition,
@@ -792,7 +854,7 @@ def get_evaluation_dataset(
use_tiling,
max_num_tiles,
use_thumbnail,
- single_image=True,
+ prompt_style="single_image",
vision_model_type=vision_model_type,
)
elif task == "VideoMME":
diff --git a/examples/multimodal/image_processing.py b/examples/multimodal/image_processing.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/manual_prompts.json b/examples/multimodal/manual_prompts.json
old mode 100755
new mode 100644
diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py
old mode 100755
new mode 100644
index 6db834e97a1d643955cf12905eb3ed84f0541a08..a28a428325b8db9c7c1268080979889935dcc396
--- a/examples/multimodal/model.py
+++ b/examples/multimodal/model.py
@@ -136,6 +136,20 @@ def model_provider(
else:
vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules
+ # Toggle --recompute* for the vision and language model separately.
+ if args.recompute_vision:
+ if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None:
+ vision_config.recompute_num_layers = vision_config.num_layers
+ else:
+ vision_config.recompute_granularity = None
+ vision_config.recompute_method = None
+ vision_config.recompute_num_layers = None
+
+ vision_projection_config.recompute_granularity = None
+ vision_projection_config.recompute_method = None
+ vision_projection_config.recompute_num_layers = None
+
+
tokenizer = get_tokenizer()
image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
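
The new toggle keeps the global `--recompute*` settings for the language model but applies them to the vision tower only when `--recompute-vision` is passed; otherwise the vision and projection configs run without activation checkpointing. A toy sketch of that decision using `SimpleNamespace` stand-ins for the real config objects:

```
from types import SimpleNamespace

# Toy stand-ins for the argument namespace and the vision transformer config.
args = SimpleNamespace(recompute_vision=False)
vision_config = SimpleNamespace(
    num_layers=24,
    recompute_method="uniform",
    recompute_granularity="full",
    recompute_num_layers=1,
)

if args.recompute_vision:
    # Inherit the global --recompute* settings and checkpoint every vision layer.
    if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None:
        vision_config.recompute_num_layers = vision_config.num_layers
else:
    # Without --recompute-vision, the vision tower skips activation checkpointing
    # even when recompute is enabled for the language model.
    vision_config.recompute_granularity = None
    vision_config.recompute_method = None
    vision_config.recompute_num_layers = None

print(vision_config.recompute_granularity)  # None
```
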
diff --git a/examples/multimodal/model_converter/clip_converter.py b/examples/multimodal/model_converter/clip_converter.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/model_converter/internvit_converter.py b/examples/multimodal/model_converter/internvit_converter.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/model_converter/siglip_converter.py b/examples/multimodal/model_converter/siglip_converter.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/model_converter/vision_model_tester.py b/examples/multimodal/model_converter/vision_model_tester.py
old mode 100755
new mode 100644
diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py
old mode 100755
new mode 100644
index 4b2be450afb33ce985ea052c55ad01abc5a3c548..eb56118e71613ea7fae6f81ff44f2969f26b4533
--- a/examples/multimodal/multimodal_args.py
+++ b/examples/multimodal/multimodal_args.py
@@ -49,7 +49,7 @@ def add_multimodal_extra_args(parser):
group.add_argument(
"--tokenizer-prompt-format",
type=str,
- choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0"],
+ choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"],
required=True,
help="Prompt format to use with the tokenizer.",
)
@@ -71,5 +71,9 @@ def add_multimodal_extra_args(parser):
group.add_argument(
"--packing-seq-length", type=int, default=0, help="Packing sequence length. Must be > 0 if using packing."
)
+ group.add_argument(
+        "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model."
+ )
+
return parser
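
A self-contained sketch of how the two additions surface through `argparse`; it mirrors the definitions above rather than importing the project's parser:

```
import argparse

# Standalone mirror of the two new options (not the project's full parser).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tokenizer-prompt-format",
    choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"],
    required=True,
)
parser.add_argument("--recompute-vision", action="store_true", default=False)

args = parser.parse_args(["--tokenizer-prompt-format", "qwen2p5", "--recompute-vision"])
print(args.tokenizer_prompt_format, args.recompute_vision)  # qwen2p5 True
```
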
diff --git a/examples/multimodal/nvlm/README.md b/examples/multimodal/nvlm/README.md
old mode 100755
new mode 100644
index 7eddbb7efa9162edb02e118ce7bb5d95151ca944..bb576bb40355a02fbe2701fdaf85d6ee9a8058e3
--- a/examples/multimodal/nvlm/README.md
+++ b/examples/multimodal/nvlm/README.md
@@ -5,6 +5,13 @@ Please refer to the [NVLM paper](https://arxiv.org/pdf/2409.11402) for details.
*NOTE: VLMs in Megatron are under active development and are expected to change.*
+# Checkpoints
+
+NVLM 1.0 model weights are publicly available in both HuggingFace and Megatron formats:
+
+- NVLM-1.0-D 72B [HuggingFace version](https://huggingface.co/nvidia/NVLM-D-72B)
+- NVLM-1.0-D 72B [Megatron-Core version](https://huggingface.co/nvidia/NVLM-D-72B-mcore)
+
# Setup
## Docker image
@@ -32,7 +39,7 @@ NVLM 1.0 34B starts from [NousResearch/Nous-Hermes-2-Yi-34B](https://huggingface
Please download it and run the following command to convert it to Megatron format.
```
python tools/checkpoint/convert.py --bf16 --model-type GPT --loader llama_mistral --saver mcore --target-tensor-parallel-size 8 --checkpoint-type hf \
- --load-dir --save-dir