Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed with stage
File mode changed from 100644 to 100755
#!/bin/bash

# Use: ./train.sh <data-path> <tokenizer-path>

MODEL_SCALE="800M" # or "8B"

case "${MODEL_SCALE}" in
    "800M")
        TENSOR_MODEL_PARALLEL_SIZE=1
        NUM_LAYERS=48
        HIDDEN_SIZE=1024
        NUM_ATTENTION_HEADS=16
        GLOBAL_BATCH_SIZE=32
        ;;
    "8B")
        TENSOR_MODEL_PARALLEL_SIZE=4
        NUM_LAYERS=56
        HIDDEN_SIZE=4096
        NUM_ATTENTION_HEADS=32
        GLOBAL_BATCH_SIZE=8
        ;;
    *)
        echo "Invalid version specified"
        exit 1
        ;;
esac

DATA_PATH=$1
TOKENIZER_PATH=$2

export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_TIMEOUT=19
export NCCL_IB_QPS_PER_CONNECTION=4

CHECKPOINT_DIR="./checkpoints"
DATACACHE_DIR="./data-cache"
TENSORBOARD_DIR="./tensorboard"

mkdir -p ${CHECKPOINT_DIR}
mkdir -p ${DATACACHE_DIR}
mkdir -p ${TENSORBOARD_DIR}

export TRITON_CACHE_DIR="./triton-cache/"
export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager"

SEQ_LEN=4096
TRAIN_SAMPLES=73242188    # 300B tokens / 4096
LR_WARMUP_SAMPLES=50000
LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES

options=" \
    --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \
    --sequence-parallel \
    --pipeline-model-parallel-size 1 \
    --use-distributed-optimizer \
    --overlap-param-gather \
    --overlap-grad-reduce \
    --untie-embeddings-and-output-weights \
    --init-method-std 0.02 \
    --position-embedding-type none \
    --num-layers ${NUM_LAYERS} \
    --hidden-size ${HIDDEN_SIZE} \
    --num-attention-heads ${NUM_ATTENTION_HEADS} \
    --group-query-attention \
    --num-query-groups 8 \
    --hybrid-attention-ratio 0.08 \
    --hybrid-mlp-ratio 0.5 \
    --seq-length ${SEQ_LEN} \
    --max-position-embeddings ${SEQ_LEN} \
    --train-samples ${TRAIN_SAMPLES} \
    --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
    --lr-decay-samples ${LR_DECAY_SAMPLES} \
    --save ${CHECKPOINT_DIR} \
    --load ${CHECKPOINT_DIR} \
    --data-path ${DATA_PATH} \
    --data-cache-path ${DATACACHE_DIR} \
    --split 99,1,0 \
    --tokenizer-type GPTSentencePieceTokenizer \
    --tokenizer-model ${TOKENIZER_PATH} \
    --distributed-backend nccl \
    --micro-batch-size 4 \
    --global-batch-size ${GLOBAL_BATCH_SIZE} \
    --lr 2.5e-4 \
    --min-lr 2.5e-5 \
    --lr-decay-style cosine \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --disable-bias-linear \
    --normalization RMSNorm \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --log-interval 10 \
    --save-interval 2000 \
    --eval-interval 2000 \
    --eval-iters 32 \
    --bf16 \
    --use-mcore-models \
    --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
    --no-create-attention-mask-in-dataloader \
    --tensorboard-dir ${TENSORBOARD_DIR}"

torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options}
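A minimal usage sketch for the script above, with placeholder paths that are not taken from this commit: the two positional arguments are the Megatron data prefix and the SentencePiece tokenizer model, and the file-mode change noted above makes the script executable.

# Hypothetical invocation of the Mamba pretraining script on one 8-GPU node.
chmod +x train.sh
./train.sh /path/to/my-corpus_text_document /path/to/tokenizer.model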
@@ -7,13 +7,13 @@ do
     fi
 done
-mpirun -np 256 --hostfile gptnodes \
+mpirun -np 32 --hostfile hostfile_mixtral_8x7B \
     --allow-run-as-root \
     --bind-to none \
     --mca plm_rsh_no_tree_spawn 1 \
-    train_GPT-MOE_567B.sh node002 --profiling=$profiling > output.log 2>&1
+    train_mixtral_8x7B_multinodes.sh node066 --profiling=$profiling > output.log 2>&1
 wait
 rm -rf CKPT
 #rm -rf mixtral_dataset/my-mixtral_text_document
\ No newline at end of file
@@ -4,18 +4,23 @@ for para in $*
 do
     if [[ $para == --profiling* ]];then
         profiling=${para#*=}
+        export GPU_FLUSH_ON_EXECUTION=1
+        export HIP_DIRECT_DISPATCH=0
     fi
 done
-source /opt/dtk/env.sh
 # Runs Mixtral 8x7B model
+source /opt/dtk/env.sh
+
+# default env
+CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
+MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
+export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH
+export GLOG_minloglevel=3
 export CUDA_DEVICE_MAX_CONNECTIONS=1
 export HSA_FORCE_FINE_GRAIN_PCIE=1
 export OMP_NUM_THREADS=1
 export GPU_MAX_HW_QUEUES=10
+# nccl env
 export NCCL_ALGO=Ring
 export NCCL_MIN_NCHANNELS=32
 export NCCL_MAX_NCHANNELS=32
@@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7
 export NCCL_NET_GDR_READ=1
 export RCCL_SDMA_COPY_ENABLE=0
 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
-#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
+export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
+# enable BatchLinear
 export GROUPED_GEMM_BatchLinear=1
-export GLOG_minloglevel=3
 RANK=$OMPI_COMM_WORLD_RANK
 LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
@@ -75,7 +81,7 @@ MOE_ARGS=(
     --moe-token-dispatcher-type alltoall
     --moe-expert-capacity-factor 0.5
     --moe-pad-expert-input-to-capacity
-    --moe-grouped-gemm
+    #--moe-grouped-gemm
 )
 DATA_ARGS=(
@@ -103,25 +109,17 @@ TRAINING_ARGS=(
 TORCH_PROFIE_ARGS=(
     --profile
-    --profile-ranks 0 1 2 3 4 5 6 7 8
+    --profile-ranks 0 1 2 3 4 5 6 7
     --profile-step-start 3
     --profile-step-end 4
-    --profile-dir torch_prof_mixtral_1nodes
+    --profile-dir torch_prof_mixtral_1nodes_tp2-pp1-ep8-ep_tp1
     --use-pytorch-profiler
 )
-HIP_PROFIE_ARGS=(
-    --profile
-    --profile-ranks 0 1 2 3 4 5 6 7 8
-    --profile-step-start 4
-    --profile-step-end 5
-    --use-hip-profiler
-)
 MODEL_PARALLEL_ARGS=(
     --tensor-model-parallel-size 2
     --pipeline-model-parallel-size 1
-    --expert-model-parallel-size 2
+    --expert-model-parallel-size 8
     --expert-tensor-parallel-size 1
     --use-distributed-optimizer
     --sequence-parallel
@@ -159,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
-elif [[ $profiling == "hip" ]]; then
-    mkdir -p hip_prof_data
-    APP+=" ${HIP_PROFIE_ARGS[@]}"
-    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 #for hygon cpu
...
@@ -4,18 +4,23 @@ for para in $*
 do
     if [[ $para == --profiling* ]];then
         profiling=${para#*=}
+        export GPU_FLUSH_ON_EXECUTION=1
+        export HIP_DIRECT_DISPATCH=0
     fi
 done
-source /opt/dtk/env.sh
 # Runs Mixtral 8x7B model
+source /opt/dtk/env.sh
+
+# default env
+CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
+MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
+export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH
+export GLOG_minloglevel=3
 export CUDA_DEVICE_MAX_CONNECTIONS=1
 export HSA_FORCE_FINE_GRAIN_PCIE=1
 export OMP_NUM_THREADS=1
 export GPU_MAX_HW_QUEUES=10
+# nccl env
 export NCCL_ALGO=Ring
 export NCCL_MIN_NCHANNELS=32
 export NCCL_MAX_NCHANNELS=32
@@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7
 export NCCL_NET_GDR_READ=1
 export RCCL_SDMA_COPY_ENABLE=0
 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
-#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
+export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
+# enable BatchLinear
 export GROUPED_GEMM_BatchLinear=1
-export GLOG_minloglevel=3
 RANK=$OMPI_COMM_WORLD_RANK
 LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
@@ -99,9 +105,6 @@ TRAINING_ARGS=(
     --bf16
     --overlap-param-gather
     --overlap-grad-reduce
-    --recompute-granularity full
-    --recompute-method uniform
-    --recompute-num-layers 1
 )
 TORCH_PROFIE_ARGS=(
@@ -109,23 +112,15 @@ TORCH_PROFIE_ARGS=(
     --profile-ranks 0 1 2 3 8 9 10 11
     --profile-step-start 3
     --profile-step-end 4
-    --profile-dir torch_prof_data_mixtral_2nodes
+    --profile-dir torch_prof_mixtral_4nodes_tp2-pp8-ep2-ep_tp1
     --use-pytorch-profiler
 )
-HIP_PROFIE_ARGS=(
-    --profile
-    --profile-ranks 0 1 2 3 8 9 10 11
-    --profile-step-start 4
-    --profile-step-end 5
-    --use-hip-profiler
-)
 MODEL_PARALLEL_ARGS=(
-    --tensor-model-parallel-size 4
-    --pipeline-model-parallel-size 4
+    --tensor-model-parallel-size 2
+    --pipeline-model-parallel-size 8
     --expert-model-parallel-size 2
-    --expert-tensor-parallel-size 2
+    --expert-tensor-parallel-size 1
     --use-distributed-optimizer
     --sequence-parallel
 )
@@ -162,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
-elif [[ $profiling == "hip" ]]; then
-    mkdir -p hip_prof_data
-    APP+=" ${HIP_PROFIE_ARGS[@]}"
-    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 #for hygon cpu
...
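A hedged sketch of how the profiling switch in the two scripts above is driven end to end; the wrapper script name is illustrative, not from this commit. The mpirun wrapper forwards --profiling=$profiling to the training script, which appends TORCH_PROFIE_ARGS only when the value is torch; the hip branch no longer exists after this commit.

# Hypothetical launcher invocation enabling the PyTorch profiler on the selected ranks.
bash run_mixtral_multinodes.sh --profiling=torch
# Any other value (or omitting the flag) now runs training without profiling.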
#!/bin/bash
# Runs Mixtral 8x7B model
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
--rotary-base 1000000
)
MOE_ARGS=(
--num-experts 8
--moe-router-topk 2
--moe-router-load-balancing-type aux_loss
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
--moe-token-dispatcher-type alltoall
--overlap-param-gather
--overlap-grad-reduce
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 256
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 4
--expert-model-parallel-size 8
--use-distributed-optimizer
--sequence-parallel
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
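A minimal launch sketch for the new script above, assuming a placeholder script name and paths: it takes the checkpoint directory, tokenizer model, and data prefix as positional arguments, and reads MASTER_ADDR/MASTER_PORT, SLURM_NNODES, and RANK from the environment for multi-node runs.

# Single-node run (hypothetical script name and paths).
bash pretrain_mixtral_8x7b.sh ./ckpt /path/to/tokenizer.model /path/to/my-mixtral_text_document

# Multi-node run: set the rendezvous variables the script reads before launching on every node, e.g.
# MASTER_ADDR=node001 MASTER_PORT=6000 SLURM_NNODES=2 RANK=<node_rank> bash pretrain_mixtral_8x7b.sh <ckpt> <tokenizer> <data>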
#!/bin/bash

MCORE_LM=$1      # <path_to_mcore_lm_model_folder>
MCORE_VISION=$2  # <path_to_mcore_vision_model_folder>
OUTPUT_DIR=$3    # <path_to_output_folder_for_combined_checkpoint>
MODEL_TYPE=$4    # Model type. Default: Mistral CLIP example.

if [[ $MODEL_TYPE == "nvlm" ]]; then
    # NVLM TP=8
    python examples/multimodal/combine_state_dicts.py \
        --input \
        ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_04/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_04/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_05/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_05/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_06/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_06/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_07/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_07/model_optim_rng.pt \
        --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model \
        --output \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_04/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_05/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_06/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_07/model_optim_rng.pt
else
    # Mistral CLIP example TP=4.
    python examples/multimodal/combine_state_dicts.py \
        --input \
        ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \
        ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \
        --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \
        --output \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \
        ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt
fi

echo 1 > ${OUTPUT_DIR}/latest_checkpointed_iteration.txt
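A usage sketch for the checkpoint-combining script above, with a hypothetical script name and placeholder paths; the four positional arguments follow the $1..$4 order defined at the top of the script, and passing nvlm selects the TP=8 branch.

# Combine TP-sharded language-model and vision-model checkpoints into one multimodal checkpoint.
bash combine_lm_vision_checkpoints.sh /ckpts/mcore_lm /ckpts/mcore_vision /ckpts/combined nvlm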
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass

import torch

from megatron.training.activations import fast_gelu, quick_gelu, squared_relu


def get_language_model_config(config):
    if config.language_model_type == "llama3_8b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "llama3.1_8b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "llama3.1_70B":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 28672
    elif config.language_model_type == "mistral_7b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "yi-34b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 20480
    elif config.language_model_type == "qwen2.5_7B":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.add_qkv_bias = True
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 18944
    elif config.language_model_type == "qwen2.0_72B":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.add_qkv_bias = True
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 29568
    elif config.language_model_type == "llama3.2_1b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 8192
    elif config.language_model_type.startswith("huggingface"):
        # Loaded from HuggingFace config file.
        pass
    else:
        raise ValueError(f"unknown language model type {config.language_model_type}")

    return config


def get_vision_model_config(config, apply_query_key_layer_scaling):
    if config.vision_model_type == "clip":
        config.num_layers = 24
        config.num_attention_heads = 16
        config.add_bias_linear = True
        config.add_qkv_bias = True
        config.hidden_size = 1024
        config.hidden_dropout = 0.0
        config.attention_dropout = 0.0
        config.ffn_hidden_size = 4096
        config.gated_linear_unit = False
        config.activation_func = quick_gelu
        config.kv_channels = 64
        config.num_query_groups = 16
        config.layernorm_zero_centered_gamma = False
        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        config.bias_activation_fusion = False
        config.bias_dropout_fusion = False
        config.attention_softmax_in_fp32 = True
        config.normalization = 'LayerNorm'
        config.apply_rope_fusion = False
    elif config.vision_model_type == "siglip":
        config.num_layers = 27
        config.num_attention_heads = 16
        config.add_bias_linear = True
        config.add_qkv_bias = True
        config.hidden_size = 1152
        config.hidden_dropout = 0.0
        config.attention_dropout = 0.0
        config.ffn_hidden_size = 4304
        config.gated_linear_unit = False
        config.activation_func = fast_gelu
        config.kv_channels = 72
        config.num_query_groups = 16
        config.layernorm_zero_centered_gamma = False
        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        config.bias_activation_fusion = False
        config.bias_dropout_fusion = False
        config.attention_softmax_in_fp32 = True
        config.normalization = 'LayerNorm'
        config.apply_rope_fusion = False
        config.qk_layernorm = False
        config.layernorm_epsilon = 1e-6
    elif config.vision_model_type == "internvit":
        config.num_layers = 45
        config.num_attention_heads = ((24 // config.tensor_model_parallel_size) + 1) * config.tensor_model_parallel_size
        config.num_query_groups = config.num_attention_heads
        config.add_bias_linear = True
        config.add_qkv_bias = False
        config.hidden_size = 3200
        config.hidden_dropout = 0.0
        config.attention_dropout = 0.0
        config.ffn_hidden_size = 12800
        config.gated_linear_unit = False
        config.activation_func = torch.nn.functional.gelu
        config.layernorm_zero_centered_gamma = False
        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        config.bias_activation_fusion = False
        config.bias_dropout_fusion = False
        config.attention_softmax_in_fp32 = True
        config.normalization = 'RMSNorm'
        config.layernorm_epsilon = 1e-6
        config.apply_rope_fusion = False
    elif config.vision_model_type == "radio":
        config.num_layers = 32
        config.num_attention_heads = 16
        config.add_bias_linear = True
        config.add_qkv_bias = True
        config.hidden_size = 1280
        config.ffn_hidden_size = 5120
        config.gated_linear_unit = False
        config.activation_func = fast_gelu
        config.kv_channels = 80
        config.num_query_groups = 16
        config.layernorm_zero_centered_gamma = False
        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        config.bias_activation_fusion = False
        config.bias_dropout_fusion = False
        config.attention_softmax_in_fp32 = True
        config.normalization = 'LayerNorm'
        config.apply_rope_fusion = False
        config.qk_layernorm = False
        config.layernorm_epsilon = 1e-6
    elif config.vision_model_type.startswith("huggingface"):
        # Loaded from HuggingFace config file.
        pass
    else:
        raise ValueError(f"unknown vision model type {config.vision_model_type}")

    return config


def get_vision_projection_config(config, hidden_size):
    config.gated_linear_unit = False
    config.bias_activation_fusion = False
    config.add_bias_linear = False
    config.hidden_size = hidden_size  # Used as the vision projection output size, i.e., the input to the language model.
    if config.language_model_type == "llama3_8b":
        config.ffn_hidden_size = 14336
        config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "llama3.1_8b":
        config.ffn_hidden_size = 4096
        config.activation_func = torch.nn.functional.gelu
        config.layernorm_epsilon = 1e-5
        config.add_bias_linear = True
        config.normalization = "LayerNorm"
    elif config.language_model_type == "mistral_7b":
        config.ffn_hidden_size = 14336
        config.activation_func = torch.nn.functional.gelu
        config.normalization = None
    elif config.language_model_type == "yi-34b":
        config.ffn_hidden_size = 20480
        config.normalization = "LayerNorm"
        config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "qwen2.5_7B":
        config.ffn_hidden_size = 3584
        config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "qwen2.0_72B":
        config.ffn_hidden_size = 29568
        config.normalization = "LayerNorm"
        config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "llama3.2_1b":
        config.ffn_hidden_size = 2048
        config.activation_func = torch.nn.functional.gelu
        config.normalization = "LayerNorm"
    elif config.language_model_type.startswith("huggingface"):
        config.activation_func = torch.nn.functional.gelu
        from transformers import AutoConfig
        hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
        if "qwen" in hf_config.model_type:
            config.ffn_hidden_size = 1536
    else:
        raise ValueError(f"unknown language model type {config.language_model_type}")

    return config


@dataclass
class EvaluationConfig:
    """Evaluation related configuration."""
    task: str

    temperature: float = 1.0
    top_p: float = 0.0
    top_k: int = 0

    out_seq_length: int = 32

    output_path: str = ""

    input_image_path: str = ""
    gt_path: str = ""

    num_partitions: int = 1
    partition_id: int = 0
    num_samples_per_partition: int = 0
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import bisect
import dataclasses
import json
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from image_processing import find_closest_aspect_ratio, find_closest_area_weighted_aspect_ratio, get_visual_transform
from PIL import Image
from torchvision.transforms import ToPILImage
import numpy as np
import torch

from energon_util import OfflineTargetAspectRatioSample, SampleListSample
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.energon import (
    Batch,
    CaptioningSample,
    DefaultTaskEncoder,
    OCRSample,
    Sample,
    SimilarityInterleavedSample,
    VQASample,
    MultiChoiceVQASample
)
from megatron.energon.task_encoder.base import stateless
from megatron.training import get_args, get_tokenizer


@dataclass
class ImageTaskSample(Sample):
    __key__: str
    __restore_key__: Tuple[Union[str, int, tuple], ...]
    __subflavor__: Dict
    __subflavors__: Dict
    # (c, h, w)
    imgs: List[torch.Tensor]
    num_tiles: List[int]
    tokens: torch.Tensor
    total_len: int  # Total token count in the sample, including text and image tokens
    labels: torch.Tensor = None


@dataclass
class ImageTaskSamplePacked(Sample):
    """Dataclass to store a single packed sample (not a batch).

    P = Number of sub-samples in the packed sample
    seq_len = Total sequence length
    num_imgs = Number of images across all samples in the packed sample
    """

    __key__: str  # Sample name
    __restore_key__: Tuple[Union[str, int, tuple], ...]
    __subflavor__: Dict  # Sample metadata. Deprecated.
    __subflavors__: Dict  # Sample metadata.
    tokens: torch.Tensor  # Input tokens packed into a single tensor (seq_len,)
    labels: torch.Tensor  # Target tokens packed into a single tensor (seq_len,)
    imgs: List[torch.Tensor]  # Input images
    num_tiles: List[int]  # Number of tiles for each image of each sample (num_imgs)
    max_length: int  # Maximum length across sub-samples.
    cu_lengths: List[int]  # Cumulative length of each sub-sample in this packed sample incl. text and image tokens (P,)


# Typing for the resulting batch data after encode_batch()
@dataclass
class ImageTaskBatchPacked(Batch):
    """Dataclass to store a batch of packed samples.

    N = Batch size
    P = Number of samples in the packed sample
    seq_len = Maximum sequence length
    num_imgs = Number of images across all samples in the packed sample
    """

    __key__: List[str]  # Sample names
    __restore_key__: Tuple[Union[str, int, tuple], ...]
    __subflavor__: Dict  # Sample metadata. Deprecated.
    __subflavors__: List[Dict]  # Sample metadatas.
    tokens: torch.Tensor  # Input tokens packed and padded (N, seq_len)
    labels: torch.Tensor  # Target tokens packed and padded (N, seq_len)
    imgs: torch.Tensor  # All image tiles stacked into a single tensor (num_tiles, C, H, W)
    num_tiles: List[List[int]]  # Number of tiles per image (N, num_imgs)
    max_lengths: List[int]  # Maximum length across sub-samples (N,)
    cu_lengths: List[List[int]]  # Cumulative length of each sub-sample in each packed sample of the batch (N, P)


# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L19
# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0.
def search_for_fit(numbers: List[int], capacity: int) -> int:
    """Finds the index of largest number that fits into the knapsack with the given capacity."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L27
# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0.
def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List:
    """Greedy algorithm with binary search for the knapsack problem.

    Pack as many samples as possible given a maximum capacity and capacities of individual samples.
    Used if sequence packing is enabled.
    """
    assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length."

    knapsacks = []

    if len(item_sizes) == 0:
        return knapsacks

    # Sort sample lengths and samples together.
    sorted_item_sizes, sorted_samples = zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0]))
    sorted_item_sizes = list(sorted_item_sizes)
    sorted_samples = list(sorted_samples)

    # Check if all samples fit in the knapsack capacity.
    if sorted_item_sizes[-1] > max_capacity:
        raise ValueError(f"knapsack: A sample is larger {sorted_item_sizes[-1]} than the max_sequence_length {max_capacity}.")

    while sorted_item_sizes:
        current_knapsack = []
        remaining_capacity = max_capacity

        while True:
            idx = search_for_fit(sorted_item_sizes, remaining_capacity)
            if idx == -1:
                break  # Can't fit more samples.

            remaining_capacity -= sorted_item_sizes[idx]

            sorted_item_sizes.pop(idx)
            sample = sorted_samples.pop(idx)
            current_knapsack.append(sample)

        knapsacks.append(current_knapsack)

    return knapsacks


class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked, dict]):
    """A simple task encoder for VLMs."""

    def __init__(
        self
    ):
        super().__init__()

        self.args = get_args()

        self.tokenizer = get_tokenizer()
        with open(self.args.prompt_path, "r") as f:
            self.manual_prompts = json.load(f)
        self.dataloader_seq_length = self.args.dataloader_seq_length  # Always return samples of this length.
        self.packing_seq_length = self.args.packing_seq_length  # Packing sequence length, if packing is enabled.
        self.is_packing_enabled = self.args.packing_buffer_size is not None and self.args.packing_buffer_size > 0

        if self.dataloader_seq_length and self.packing_seq_length:
            assert self.dataloader_seq_length >= self.packing_seq_length, "dataloader sequence length must be greater than or equal to the packing sequence length"

        if self.is_packing_enabled:
            assert self.packing_seq_length > 0, "packing sequence length must be set"

        self.num_image_embeddings_per_tile = get_num_image_embeddings(
            self.args.img_h,
            self.args.img_w,
            self.args.patch_dim,
            self.args.vision_model_type,
            self.args.disable_vision_class_token,
            1,
            self.args.pixel_shuffle,
            self.args.use_tile_tags,
        )

        self.txt_to_token_dict = {}

        self.img_h, self.img_w = self.args.img_h, self.args.img_w
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

        # This map is used to reduce the number of tiles used per image if the number of tokens is
        # larger than the decoder_seq_length.
        self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}

        self.find_closest_aspect_ratio_fn = (
            find_closest_area_weighted_aspect_ratio if self.args.use_area_weighted_aspect_ratio
            else find_closest_aspect_ratio)

    def _get_total_seq_length(self, input_ids, num_tiles):
        """Calculate expected sequence length given text tokens length and number of tiles."""
        total_num_images = len(num_tiles)
        total_num_tiles = sum(num_tiles)
        total_len = len(input_ids) + total_num_tiles * self.num_image_embeddings_per_tile - total_num_images
        return total_len

    def _truncate_for_packing(self, input_ids, target, num_tiles):
        """Truncate tokens and labels if they exceed packing sequence length."""
        total_num_images = len(num_tiles)
        total_num_tiles = sum(num_tiles)
        total_img_embeddings_len = total_num_tiles * self.num_image_embeddings_per_tile
        max_text_tokens = self.packing_seq_length - total_img_embeddings_len + total_num_images

        input_ids = input_ids[:max_text_tokens]
        target = target[:max_text_tokens]

        # If truncate causes all labels to be ignored, then skip the sample
        if (target == IGNORE_INDEX).all():
            raise ValueError(f"all targets will be ignored after truncation: {input_ids}")

        return input_ids, target

    @stateless(restore_seeds=True)
    def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, SimilarityInterleavedSample]):
        if isinstance(sample, OCRSample):
            if "pdfa" in sample.__key__:
                yield self.combined_ocr_encoder(sample, task_type='encode_pdf')
            elif "multi" in sample.__key__:
                yield self.combined_ocr_encoder(sample, task_type='_encode_ocr')
            else:
                yield self.combined_ocr_encoder(sample, task_type='encode_ocr_ref')
        elif isinstance(sample, CaptioningSample):
            yield self.encode_captioning(sample)
        elif isinstance(sample, VQASample):
            is_llava_training = sample.__subflavors__["is_llava_training"] if "is_llava_training" in sample.__subflavors__ else False

            if "llava" in sample.__key__ or is_llava_training:
                yield self.encode_llava_pretrain(sample)
            else:
                yield self.encode_any_single_turn_vqa(sample)
        elif isinstance(sample, SimilarityInterleavedSample):
            yield self.encode_llava_sft(sample)
        elif isinstance(sample, MultiChoiceVQASample):
            yield self.encode_any_single_turn_vqa(sample)
        # Because the SampleListSample is defined in the Megatron module but loaded by the Energon
        # library, we need to resort to the more brittle check:
        elif type(sample).__name__ == "SampleListSample":
            yield self.encode_sample_list(sample)
        else:
            raise NotImplementedError("Sample format not supported", sample)

    def encode_captioning(self, sample: CaptioningSample):
        """Encode CaptioningSample."""
        augment = sample.__subflavors__.get("augmentation")

        imgs = get_visual_transform(
            sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
            self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
        )
        num_tiles = [len(imgs)]

        prompt_list = self.manual_prompts["CaptioningPretraining"]["raw"]

        prompt_idx = np.random.randint(len(prompt_list))
        cur_prompt = prompt_list[prompt_idx]
        cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n"

        caption = sample.caption.strip()

        split_by_line_flag = sample.__subflavors__.get("SplitByLine")
        if split_by_line_flag:
            caption_list = caption.split('\n')
            caption = np.random.choice(caption_list)

        conv = [
            # Note: no system message.
            {"role": "user", "content": cur_prompt},
            {"role": "assistant", "content": caption},
        ]

        input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False)

        if self.is_packing_enabled:
            input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)

        return ImageTaskSample(
            __key__=sample.__key__,
            __restore_key__=sample.__restore_key__,
            __subflavor__=None,
            __subflavors__=sample.__subflavors__,
            imgs=imgs,
            num_tiles=num_tiles,
            tokens=torch.tensor(input_ids),
            labels=torch.tensor(target),
            total_len=self._get_total_seq_length(input_ids, num_tiles),
        )

    def encode_llava_pretrain(self, sample: VQASample):
        """Encode pretrain sample in LLAVA style."""
        augment = sample.__subflavors__.get("augmentation", False)

        imgs = get_visual_transform(
            sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
            self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
        )
        num_tiles = [len(imgs)]

        # LLAVA training: override text-prompt with just the image.
        conv = [
            # Note: no system message.
            {"role": "user", "content": IMAGE_TOKEN + "\n"},
            {"role": "assistant", "content": sample.answers},
        ]

        input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False)

        if self.is_packing_enabled:
            input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)

        return ImageTaskSample(
            __key__=sample.__key__,
            __restore_key__=sample.__restore_key__,
            __subflavor__=None,
            __subflavors__=sample.__subflavors__,
            imgs=imgs,
            num_tiles=num_tiles,
            tokens=torch.tensor(input_ids),
            labels=torch.tensor(target),
            total_len=self._get_total_seq_length(input_ids, num_tiles),
        )

    def encode_sample_list(self, samples: SampleListSample):
        """We encode the list of samples using encode_llava_sft on each sample."""
        error_msg = ("You probably don't want to use online packing since SampleListSample is "
if len(sample.images) > 0 and not has_video: "usually used along offline packing.")
has_image = True assert not self.is_packing_enabled, error_msg
encoded_samples = []
# Note: Some tokenizers may ignore the system prompt. current_length = 0
conversation = [{"role": "system", "content": "Answer the questions."}] for sample in samples.samples:
# Format the conversation as a list of "user" / "assistant" turns. encoded_sample = self.encode_llava_sft(sample, truncate_for_sample_list_packing=True)
for text in sample.texts: if current_length + encoded_sample.total_len > self.packing_seq_length:
error_msg = f"unexpected role {text['from']} in {sample.texts}" break
assert text["from"] in ["human", "gpt"], error_msg else:
conversation.append({ encoded_samples.append(encoded_sample)
"role": "user" if text["from"] == "human" else "assistant", current_length += encoded_sample.total_len
"content": text["value"]}) return self.pack_selected_samples(encoded_samples)
# Replace the image tags <image-idx> with IMAGE_TOKEN and count the number of image tags def encode_llava_sft(self, sample: Union[SimilarityInterleavedSample, OfflineTargetAspectRatioSample], truncate_for_sample_list_packing=False):
number_image_tags = 0 """Encode SFT sample."""
image_tag_ids_list = [] augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
for turn in conversation: has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
if turn["role"] == "user":
image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])] # If the target aspect ratio are provided by the dataset, we use them instead of computing
image_tag_ids_list.extend(image_tag_ids) # them with the self.find_closest_aspect_ratio_fn function.
turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"]) local_find_closest_aspect_ratio_fn = self.find_closest_aspect_ratio_fn
number_image_tags += turn["content"].count(IMAGE_TOKEN) if type(sample).__name__ == "OfflineTargetAspectRatioSample":
# For videos, we replace the image tag with the video tag target_aspect_ratio = tuple(sample.target_aspect_ratio[0])
if has_video: assert target_aspect_ratio is not None, "Sample of type OfflineTargetAspectRatioSample needs to define the target aspect ratio."
turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN) local_find_closest_aspect_ratio_fn = lambda *args, **kwargs: target_aspect_ratio
# We re-order the images in sample.images according to how they appear in the conversation. has_image = False
if len(image_tag_ids_list) > 0: # We infer whether the sample has image or not.
sample.images = [sample.images[idx] for idx in image_tag_ids_list] if hasattr(sample, "images") and not has_video:
# If this is a text-only sample and we are freezing the LM,
# If there is only one image, but several image tags, we assume all the tags refer to the # then use a dummy input image.
# same image and duplicate the image: if len(sample.images) == 0 and self.args.freeze_LM:
if len(sample.images) == 1 and number_image_tags > 1: empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
sample.images = sample.images * number_image_tags sample.images.append(empty_img)
if len(sample.images) > 0:
number_of_images = len(sample.images) has_image = True
# Fail if there are more image or video tags than image or videos:
error_msg = ( # Note: Some tokenizers may ignore the system prompt.
f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}") conversation = [{"role": "system", "content": "Answer the questions."}]
assert number_image_tags <= number_of_images, error_msg # Format the conversation as a list of "user" / "assistant" turns.
for text in sample.texts:
# If there are less image of video tags than image or videos, prepend the tags to the first error_msg = f"unexpected role {text['from']} in {sample.texts}"
# user message: assert text["from"] in ["human", "gpt"], error_msg
if number_image_tags < number_of_images: conversation.append({
for turn in conversation: "role": "user" if text["from"] == "human" else "assistant",
if turn["role"] == "user": "content": text["value"]})
tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"] # Replace the image tags <image-idx> with IMAGE_TOKEN and count the number of image tags
break number_image_tags = 0
image_tag_ids_list = []
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) for turn in conversation:
if turn["role"] == "user":
if has_image: image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
imgs = [] image_tag_ids_list.extend(image_tag_ids)
num_tiles = [] turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
max_num_tiles = self.args.max_num_tiles # For videos, we use the image token to locate where to put the frames.
# We keep a buffer of 4 tokens for the question, if has_video:
# the rest can be used for image tokens. turn["content"] = turn["content"].replace(VIDEO_TOKEN, IMAGE_TOKEN)
max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4 number_image_tags += turn["content"].count(IMAGE_TOKEN)
# We start by extracting as many tiles per image as possible, and decrease the max
# number of tiles if there are too many image tokens. # We re-order the images in sample.images according to how they appear in the conversation.
while True: if len(image_tag_ids_list) > 0:
imgs = [] sample.images = [sample.images[idx] for idx in image_tag_ids_list]
num_tiles = []
for img in sample.images: # If there is only one image, but several image tags, we assume all the tags refer to the
img_tiles = get_visual_transform( # same image and duplicate the image:
img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles, if not has_video and len(sample.images) == 1 and number_image_tags > 1:
self.args.use_thumbnail, augment, self.args.vision_model_type) sample.images = sample.images * number_image_tags
imgs += img_tiles
num_tiles += [len(img_tiles)] # We currently only support one video per sample.
if max_num_tiles == 1: number_of_images = 1 if has_video else len(sample.images)
break # Fail if there are more image or video tags than image or videos:
if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed: error_msg = (
if max_num_tiles in self.num_tiles_degradation_map: f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
max_num_tiles = self.num_tiles_degradation_map[max_num_tiles] assert number_image_tags <= number_of_images, error_msg
else:
raise RuntimeError(( # If there are less image of video tags than image or videos, prepend the tags to the first
f"Tried to decrease the number of tiles {max_num_tiles} but it's not ", # user message:
f"defined in the degradation map {self.num_tiles_degradation_map}")) if number_image_tags < number_of_images:
else: for turn in conversation:
break if turn["role"] == "user":
elif has_video: turn["content"] = IMAGE_TOKEN*(number_of_images-number_image_tags) + "\n" + turn["content"]
# We don't use tiling for videos to limit the number of tokens. break
use_tiling=False
# Grab the selected frames of the video as a tensor with shape input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
# fhwc: (num_frames, num_channels, height, width).
video_fchw = sample.images[0].permute(0, 1, 2, 3) if has_image:
selected_frames = torch.linspace( imgs = []
0, video_fchw.shape[0] - 1, self.args.num_frames).long() num_tiles = []
video_fchw = video_fchw[selected_frames] max_num_tiles = self.args.max_num_tiles
imgs = [] # We keep a buffer of 4 tokens for the question,
for video_chw in video_fchw: # the rest can be used for image tokens.
to_pil = ToPILImage() max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4
video_chw = to_pil(video_chw) # We start by extracting as many tiles per image as possible, and decrease the max
imgs += get_visual_transform( # number of tiles if there are too many image tokens.
video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles, while True:
self.args.use_thumbnail, augment, self.args.vision_model_type) imgs = []
num_tiles = [len(imgs)] num_tiles = []
else: for img in sample.images:
imgs = num_tiles = [] img_tiles = get_visual_transform(
img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
if self.is_packing_enabled: self.args.use_thumbnail, augment, self.args.vision_model_type,
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn)
imgs += img_tiles
# Some final checks with respect to the number of image tokens and images on the tokenized num_tiles += [len(img_tiles)]
# conversation. There can still be errors, for instance if a non-video sample happens to if max_num_tiles == 1:
# have our pre-defined video token, or if the packing truncation removed a necessary image break
# tag. if sum(num_tiles) * self.num_image_embeddings_per_tile > max_image_token_allowed:
number_image_token = np.sum(input_ids == self.img_token_id) if max_num_tiles in self.num_tiles_degradation_map:
error_msg = ( max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.") else:
assert number_image_token == len(num_tiles), error_msg raise RuntimeError((
error_msg = ( f"Tried to decrease the number of tiles {max_num_tiles} but it's not ",
f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.") f"defined in the degradation map {self.num_tiles_degradation_map}"))
assert np.sum(num_tiles) == len(imgs), error_msg else:
break
return ImageTaskSample( elif has_video:
__key__=sample.__key__, # We don't use tiling for videos to limit the number of tokens.
__restore_key__=sample.__restore_key__, use_tiling=False
__subflavor__=None, # Grab the selected frames of the video as a tensor with shape
__subflavors__=sample.__subflavors__, # fhwc: (num_frames, num_channels, height, width).
imgs=imgs, video_fchw = sample.images.frames
num_tiles=num_tiles, if video_fchw.shape[0] == 0:
tokens=torch.tensor(input_ids), raise ValueError(f"Video {sample.__key__} {sample.__restore_key__} {sample.texts} has no frames.")
labels=torch.tensor(target), selected_frames = torch.linspace(
total_len=self._get_total_seq_length(input_ids, num_tiles), 0, video_fchw.shape[0] - 1, self.args.num_frames).long()
) video_fchw = video_fchw[selected_frames]
imgs = []
def encode_any_single_turn_vqa(self, sample): for video_chw in video_fchw:
"""Encode MultiChoiceVQA or VQA sample.""" to_pil = ToPILImage()
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False video_chw = to_pil(video_chw)
has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False imgs += get_visual_transform(
video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
if has_video: self.args.use_thumbnail, augment, self.args.vision_model_type,
# Grab the selected frames of the video as a tensor with shape find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn)
# fhwc: (num_frames, height, width, num_channels). num_tiles = [len(imgs)]
video_fhwc = sample.image.permute(0, 2, 3, 1) else:
selected_frames = torch.linspace( imgs = num_tiles = []
0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
video_frame_fhwc = video_fhwc[selected_frames] if self.is_packing_enabled or truncate_for_sample_list_packing:
imgs = [] input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
for video_frame_hwc in video_frame_fhwc:
imgs += get_visual_transform( # Some final checks with respect to the number of image tokens and images on the tokenized
video_frame_hwc, self.img_h, self.img_w, # conversation. There can still be errors, for instance if a non-video sample happens to
self.args.use_tiling, self.args.max_num_tiles, # have our pre-defined video token, or if the packing truncation removed a necessary image
self.args.use_thumbnail, augment, self.args.vision_model_type) # tag.
else: number_image_token = np.sum(input_ids == self.img_token_id)
imgs = get_visual_transform( error_msg = (
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.")
self.args.use_thumbnail, augment, self.args.vision_model_type, assert number_image_token == len(num_tiles), error_msg
) error_msg = (
f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
num_tiles = [len(imgs)] assert np.sum(num_tiles) == len(imgs), error_msg
if isinstance(sample, MultiChoiceVQASample): # We need to ensure that there are at least some trainable tokens in the sample.
cur_prompt = format_multichoice_question(sample.context, sample.choices) assert self.target_has_trainable_tokens(input_ids, num_tiles, target), "Sample has no trainable tokens."
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt return ImageTaskSample(
cur_answer = format_multichoice_answer(sample.correct_choice_idx) __key__=sample.__key__,
elif isinstance(sample, VQASample): __restore_key__=sample.__restore_key__,
if 'docvqa' in sample.__key__: __subflavor__=None,
prompt_list = self.manual_prompts["VQASFT"]["docvqa"] __subflavors__=sample.__subflavors__,
elif sample.__subflavors__.get("VQASFT"): imgs=imgs,
prompt_list = self.manual_prompts["VQASFT"]["raw"] num_tiles=num_tiles,
else: tokens=torch.tensor(input_ids),
prompt_list = ["{}"] labels=torch.tensor(target),
total_len=self._get_total_seq_length(input_ids, num_tiles),
prompt_idx = np.random.randint(len(prompt_list)) )
cur_prompt = prompt_list[prompt_idx]
def target_has_trainable_tokens(self, input_ids, num_tiles, target):
cur_prompt = cur_prompt.format(sample.context) # Compute the loss mask based on extending the image tags with the proper
# number of image tokens, extracting the first self.args.decoder_seq_length tokens, and
if IMAGE_TOKEN not in cur_prompt: # ensuring that some of these tokens have a loss mask > 0.
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt # Note that this is a bit hacky because we reproduce here parts of the logics which are in
# the model itself. Ideally, the data sampler would return the already processed inputs
if isinstance(sample.answers, list): # and targets to avoid this duplication.
answer_list = sample.answers expanded_target = target.copy()
weight_list = np.array(sample.answer_weights).astype(np.float32) expanded_target[input_ids==self.img_token_id] = self.img_token_id
weight_list = weight_list / np.sum(weight_list) expanded_target = self.replace_value_with_repetition(
answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0] expanded_target, self.img_token_id,
cur_answer = answer_list[answer_idx] self.num_image_embeddings_per_tile * np.array(num_tiles), IGNORE_INDEX)
else: loss_mask = torch.ones(torch.tensor(expanded_target).size(), dtype=torch.float)
cur_answer = sample.answers loss_mask[expanded_target == self.tokenizer.pad] = 0.0 # mask paddings
else: loss_mask[expanded_target == IGNORE_INDEX] = 0.0 # mask prompts
raise NotImplementedError("Unsupported data type provided", sample) loss_mask = torch.cat((loss_mask[1:], torch.zeros((1,))))
loss_mask = loss_mask[:self.args.decoder_seq_length]
conversation = [ return torch.sum(loss_mask) > 0
{"role": "system", "content": "Answer the questions."},
{"role": "user", "content": cur_prompt}, def replace_value_with_repetition(self, arr, token_to_replace, num_repetition, new_token):
{"role": "assistant", "content": str(cur_answer)}, """
] Replace every occurrence of value V in the input array with R repetitions of W.
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) Args:
arr (Array): Input array to be modified
if self.is_packing_enabled: token_to_replace: token to be replaced
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) new_token: new token
num_repetition (Array): number of repetition of new token.
return ImageTaskSample(
__key__=sample.__key__, Returns:
__restore_key__=sample.__restore_key__, Array: New array with token_to_replace replaced by num_repetition repetitions of
__subflavor__=None, new_token
__subflavors__=sample.__subflavors__, """
imgs=imgs, error_msg = "The number of image tokens must match the length of the tile tensor."
num_tiles=num_tiles, assert np.sum(arr==token_to_replace) == len(num_repetition), error_msg
tokens=torch.tensor(input_ids), result = []
labels=torch.tensor(target), idx = 0
total_len=self._get_total_seq_length(input_ids, num_tiles), for item in arr:
) if item == token_to_replace:
# If the current item matches token_to_replace, add R copies of W
def combined_ocr_encoder(self, sample, task_type): result.extend([new_token] * num_repetition[idx])
"""Encode OCR samples.""" idx += 1
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False else:
# Otherwise, keep the original item
if task_type == "encode_pdf": result.append(item)
sample, cur_prompt, cur_answer = self.encode_pdf_prompt(sample)
elif task_type == "encode_ocr_ref": return np.array(result)
sample, cur_prompt, cur_answer = self.encode_ocr_ref_prompt(sample)
elif task_type == "_encode_ocr": def encode_any_single_turn_vqa(self, sample):
sample, cur_prompt, cur_answer = self.encode_ocr_prompt(sample) """Encode MultiChoiceVQA or VQA sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
imgs = get_visual_transform( has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type, if has_video:
) # Grab the selected frames of the video as a tensor with shape
num_tiles = [len(imgs)] # fhwc: (num_frames, height, width, num_channels).
video_fhwc = sample.image.permute(0, 2, 3, 1)
conversation = [ selected_frames = torch.linspace(
{"role": "system", "content": "Answer the questions."}, 0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
{"role": "user", "content": cur_prompt}, video_frame_fhwc = video_fhwc[selected_frames]
{"role": "assistant", "content": str(cur_answer)}, imgs = []
] for video_frame_hwc in video_frame_fhwc:
imgs += get_visual_transform(
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) video_frame_hwc, self.img_h, self.img_w,
self.args.use_tiling, self.args.max_num_tiles,
if self.is_packing_enabled: self.args.use_thumbnail, augment, self.args.vision_model_type,
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn)
else:
return ImageTaskSample( imgs = get_visual_transform(
__key__=sample.__key__, sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles,
__restore_key__=sample.__restore_key__, self.args.use_thumbnail, augment, self.args.vision_model_type,
__subflavor__=None, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
__subflavors__=sample.__subflavors__, )
imgs=imgs,
num_tiles=num_tiles, num_tiles = [len(imgs)]
tokens=torch.tensor(input_ids),
labels=torch.tensor(target), if isinstance(sample, MultiChoiceVQASample):
total_len=self._get_total_seq_length(input_ids, num_tiles), cur_prompt = format_multichoice_question(sample.context, sample.choices)
) if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample: cur_answer = format_multichoice_answer(sample.correct_choice_idx)
"""Encode OCR sample.""" elif isinstance(sample, VQASample):
prompt_list = self.manual_prompts["DocPretraining"]["raw"] if 'docvqa' in sample.__key__:
prompt_idx = np.random.randint(len(prompt_list)) prompt_list = self.manual_prompts["VQASFT"]["docvqa"]
cur_prompt = prompt_list[prompt_idx] elif sample.__subflavors__.get("VQASFT"):
if IMAGE_TOKEN not in cur_prompt: prompt_list = self.manual_prompts["VQASFT"]["raw"]
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt else:
prompt_list = ["{}"]
# Make sure there is no extra IMAGE_TOKEN tag.
sample.text = sample.text.replace(IMAGE_TOKEN, "") prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
caption = sample.text.strip()
cur_prompt = cur_prompt.format(sample.context)
split_by_line_flag = sample.__subflavors__.get("SplitByLine")
if split_by_line_flag: if IMAGE_TOKEN not in cur_prompt:
caption_list = caption.split('\n') cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
caption = np.random.choice(caption_list)
cur_answer = caption if isinstance(sample.answers, list):
answer_list = sample.answers
return sample, cur_prompt, cur_answer weight_list = np.array(sample.answer_weights).astype(np.float32)
weight_list = weight_list / np.sum(weight_list)
def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample: answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0]
"""Encode OCR sample.""" cur_answer = answer_list[answer_idx]
ref = sample.text else:
region = sample.words_boxes cur_answer = sample.answers
else:
# Make sure there is no extra IMAGE_TOKEN tag raise NotImplementedError("Unsupported data type provided", sample)
ref = ref.replace(IMAGE_TOKEN, "")
conversation = [
if len(region) == 4: {"role": "system", "content": "Answer the questions."},
region = f"<box>({region[0]},{region[1]}),({region[2]},{region[3]})</box>" {"role": "user", "content": cur_prompt},
else: {"role": "assistant", "content": str(cur_answer)},
region = f"<quad>({region[0]},{region[1]}),({region[2]},{region[3]}),({region[4]},{region[5]}),({region[6]},{region[7]})</quad>" ]
# Randomly choose between two tasks input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
task_idx = np.random.randint(2)
if task_idx == 0: if self.is_packing_enabled:
# Referring Grounding input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
prompt_list = self.manual_prompts["DocPretraining"]["referring_grounding"]
prompt_content = ref return ImageTaskSample(
answer = region __key__=sample.__key__,
else: __restore_key__=sample.__restore_key__,
# Grounded OCR __subflavor__=None,
prompt_list = self.manual_prompts["DocPretraining"]["grounded_ocr"] __subflavors__=sample.__subflavors__,
prompt_content = region imgs=imgs,
answer = ref num_tiles=num_tiles,
tokens=torch.tensor(input_ids),
prompt_idx = np.random.randint(len(prompt_list)) labels=torch.tensor(target),
cur_prompt = prompt_list[prompt_idx] total_len=self._get_total_seq_length(input_ids, num_tiles),
cur_prompt = cur_prompt.format(prompt_content) )
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt def combined_ocr_encoder(self, sample, task_type):
"""Encode OCR samples."""
return sample, cur_prompt, answer augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
def bbox_coord_to_label(self, text, bbox): if task_type == "encode_pdf":
"""Format bbox coordinates as text.""" sample, cur_prompt, cur_answer = self.encode_pdf_prompt(sample)
assert len(bbox) == 4 or len(bbox) == 8 elif task_type == "encode_ocr_ref":
sample, cur_prompt, cur_answer = self.encode_ocr_ref_prompt(sample)
# Make sure there is no extra IMAGE_TOKEN tag elif task_type == "_encode_ocr":
text = text.replace(IMAGE_TOKEN, "") sample, cur_prompt, cur_answer = self.encode_ocr_prompt(sample)
if len(bbox) == 4: imgs = get_visual_transform(
label_str = f"<ref>{text}</ref><box>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})</box>" sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles,
else: self.args.use_thumbnail, augment, self.args.vision_model_type,
label_str = f"<ref>{text}</ref><quad>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]}),({bbox[4]},{bbox[5]}),({bbox[6]},{bbox[7]})</quad>" find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
)
return label_str num_tiles = [len(imgs)]
def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample: conversation = [
"""Encode OCR sample.""" {"role": "system", "content": "Answer the questions."},
if isinstance(sample.words_boxes[0], int): {"role": "user", "content": cur_prompt},
answer = self.bbox_coord_to_label(sample.text, sample.words_boxes) {"role": "assistant", "content": str(cur_answer)},
elif isinstance(sample.words_boxes[0], list): ]
answer = ""
for i, bbox in enumerate(sample.words_boxes): input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
answer += self.bbox_coord_to_label(sample.words_text[i], bbox)
if self.is_packing_enabled:
prompt_list = self.manual_prompts["DocPretraining"]["ocr_multi"] input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx] return ImageTaskSample(
__key__=sample.__key__,
if IMAGE_TOKEN not in cur_prompt: __restore_key__=sample.__restore_key__,
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt __subflavor__=None,
cur_answer = answer __subflavors__=sample.__subflavors__,
imgs=imgs,
return sample, cur_prompt, cur_answer num_tiles=num_tiles,
tokens=torch.tensor(input_ids),
def batch(self, samples: List[Union[ImageTaskSample, ImageTaskSamplePacked]]) -> ImageTaskBatchPacked: labels=torch.tensor(target),
# Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image. total_len=self._get_total_seq_length(input_ids, num_tiles),
imgs = [img for s in samples for img in s.imgs] )
if len(imgs) > 0:
imgs = torch.stack(imgs) def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample:
else: """Encode OCR sample."""
imgs = torch.tensor([[0]], dtype=torch.float32) prompt_list = self.manual_prompts["DocPretraining"]["raw"]
prompt_idx = np.random.randint(len(prompt_list))
# If the user hasn't defined a target dataloader sequence length, then use the max along the sample lengths. cur_prompt = prompt_list[prompt_idx]
max_seq_len = self.dataloader_seq_length if IMAGE_TOKEN not in cur_prompt:
if not max_seq_len: cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
max_seq_len = max(len(s.tokens) for s in samples)
# Make sure there is no extra IMAGE_TOKEN tag.
tokens = np.full((len(samples), max_seq_len), self.tokenizer.pad, dtype=np.int64) sample.text = sample.text.replace(IMAGE_TOKEN, "")
# +1 to accommodate shift to left by one later.
labels = np.full((len(samples), max_seq_len + 1), self.tokenizer.pad, dtype=np.int64) caption = sample.text.strip()
for i, s in enumerate(samples): split_by_line_flag = sample.__subflavors__.get("SplitByLine")
# If the sample/target length exceeds the target sequence length, then truncate. if split_by_line_flag:
text_len = min(max_seq_len, len(s.tokens)) caption_list = caption.split('\n')
target_len = min(max_seq_len+1, len(s.labels)) caption = np.random.choice(caption_list)
cur_answer = caption
tokens[i, :text_len] = s.tokens[:text_len]
labels[i, :target_len] = s.labels[:target_len] return sample, cur_prompt, cur_answer
num_tiles = torch.tensor([n for s in samples for n in s.num_tiles], dtype=torch.int32) def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample:
if len(num_tiles) == 0: """Encode OCR sample."""
num_tiles = torch.tensor([[0]], dtype=torch.int32) ref = sample.text
region = sample.words_boxes
# Cumulative sample lengths are needed for packing, otherwise use dummy values.
cu_lengths = torch.tensor([[0]], dtype=torch.int32) # Make sure there is no extra IMAGE_TOKEN tag
max_lengths = torch.tensor([[0]], dtype=torch.int32) ref = ref.replace(IMAGE_TOKEN, "")
if self.is_packing_enabled: if len(region) == 4:
cu_lengths = torch.stack([s.cu_lengths for s in samples]) region = f"<box>({region[0]},{region[1]}),({region[2]},{region[3]})</box>"
max_lengths = torch.tensor([s.max_length for s in samples], dtype=torch.int32) else:
region = f"<quad>({region[0]},{region[1]}),({region[2]},{region[3]}),({region[4]},{region[5]}),({region[6]},{region[7]})</quad>"
return ImageTaskBatchPacked(
__key__=[s.__key__ for s in samples], # Randomly choose between two tasks
__restore_key__=[s.__restore_key__ for s in samples], task_idx = np.random.randint(2)
__subflavor__=None, if task_idx == 0:
__subflavors__=samples[0].__subflavors__, # Referring Grounding
tokens=tokens, prompt_list = self.manual_prompts["DocPretraining"]["referring_grounding"]
labels=labels, prompt_content = ref
imgs=imgs, answer = region
num_tiles=num_tiles, else:
cu_lengths=cu_lengths, # Grounded OCR
max_lengths=max_lengths, prompt_list = self.manual_prompts["DocPretraining"]["grounded_ocr"]
) prompt_content = region
answer = ref
def encode_batch(self, batch: ImageTaskBatchPacked) -> dict:
raw = dataclasses.asdict(batch) prompt_idx = np.random.randint(len(prompt_list))
del raw["__subflavors__"] cur_prompt = prompt_list[prompt_idx]
return raw cur_prompt = cur_prompt.format(prompt_content)
if IMAGE_TOKEN not in cur_prompt:
def select_samples_to_pack(self, samples: List[ImageTaskSample]) -> List[List[ImageTaskSample]]: cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
"""Selects which samples will be packed together.
return sample, cur_prompt, answer
NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html def bbox_coord_to_label(self, text, bbox):
""" """Format bbox coordinates as text."""
lengths = [sample.total_len for sample in samples] assert len(bbox) == 4 or len(bbox) == 8
packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length) # Make sure there is no extra IMAGE_TOKEN tag
text = text.replace(IMAGE_TOKEN, "")
return packed_samples
if len(bbox) == 4:
@stateless label_str = f"<ref>{text}</ref><box>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})</box>"
def pack_selected_samples(self, samples: List[ImageTaskSample]) -> List[ImageTaskSamplePacked]: else:
""" label_str = f"<ref>{text}</ref><quad>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]}),({bbox[4]},{bbox[5]}),({bbox[6]},{bbox[7]})</quad>"
Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.
return label_str
NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample:
"""Encode OCR sample."""
Args: if isinstance(sample.words_boxes[0], int):
samples: List of ImageTaskSample instances to pack into one sample. answer = self.bbox_coord_to_label(sample.text, sample.words_boxes)
elif isinstance(sample.words_boxes[0], list):
Returns: answer = ""
ImageTaskSamplePacked instance. for i, bbox in enumerate(sample.words_boxes):
""" answer += self.bbox_coord_to_label(sample.words_text[i], bbox)
packing_seq_len = self.packing_seq_length
prompt_list = self.manual_prompts["DocPretraining"]["ocr_multi"]
packed_tokens = [] prompt_idx = np.random.randint(len(prompt_list))
packed_labels = [] cur_prompt = prompt_list[prompt_idx]
packed_imgs = []
if IMAGE_TOKEN not in cur_prompt:
current_length = 0 cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
max_length = 0 cur_answer = answer
cu_lengths = [0]
return sample, cur_prompt, cur_answer
# Process each sample and build lists that we will concatenate to create the packed sample.
for _, sample in enumerate(samples): def batch(self, samples: List[Union[ImageTaskSample, ImageTaskSamplePacked]]) -> ImageTaskBatchPacked:
sample_len = sample.total_len # Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image.
imgs = [img for s in samples for img in s.imgs]
if sample_len > max_length: if len(imgs) > 0:
max_length = sample_len imgs = torch.stack(imgs)
else:
# If adding this sample exceeds the max length, stop. imgs = torch.tensor([[0]], dtype=torch.float32)
# This should not happen. The select_samples_to_pack method should have already ensured that the samples fit.
if current_length + sample_len > packing_seq_len: # If the user hasn't defined a target dataloader sequence length, then use the max along the sample lengths.
raise ValueError(f"Packed sample exceeds the maximum sequence length of {packing_seq_len}: {samples}") max_seq_len = self.dataloader_seq_length
if not max_seq_len:
# Add the sample's tokens and labels max_seq_len = max(len(s.tokens) for s in samples)
packed_tokens.append(sample.tokens)
packed_labels.append(sample.labels) tokens = np.full((len(samples), max_seq_len), self.tokenizer.pad, dtype=np.int64)
# +1 to accommodate shift to left by one later.
# Add the images labels = np.full((len(samples), max_seq_len + 1), self.tokenizer.pad, dtype=np.int64)
packed_imgs += sample.imgs
for i, s in enumerate(samples):
current_length += sample_len # If the sample/target length exceeds the target sequence length, then truncate.
cu_lengths.append(current_length) text_len = min(max_seq_len, len(s.tokens))
target_len = min(max_seq_len+1, len(s.labels))
# Concatenate packed tokens and labels.
packed_tokens = torch.cat(packed_tokens, dim=0) tokens[i, :text_len] = s.tokens[:text_len]
packed_labels = torch.cat(packed_labels, dim=0) labels[i, :target_len] = s.labels[:target_len]
return ImageTaskSamplePacked( num_tiles = torch.tensor([n for s in samples for n in s.num_tiles], dtype=torch.int32)
__key__=",".join([s.__key__ for s in samples]), if len(num_tiles) == 0:
__restore_key__=(), # Will be set by energon based on `samples` num_tiles = torch.tensor([[0]], dtype=torch.int32)
__subflavor__=None,
__subflavors__=samples[0].__subflavors__, # Cumulative sample lengths are needed for packing, otherwise use dummy values.
tokens=packed_tokens, cu_lengths = torch.tensor([[0]], dtype=torch.int32)
labels=packed_labels, max_lengths = torch.tensor([[0]], dtype=torch.int32)
imgs=packed_imgs,
cu_lengths=torch.tensor(cu_lengths, dtype=torch.int32), if self.is_packing_enabled:
max_length=max_length, cu_lengths = torch.stack([s.cu_lengths for s in samples])
num_tiles=[n for s in samples for n in s.num_tiles], max_lengths = torch.tensor([s.max_length for s in samples], dtype=torch.int32)
)
return ImageTaskBatchPacked(
__key__=[s.__key__ for s in samples],
def print_error_handler(exc: Exception, key: Optional[str]): __restore_key__=[s.__restore_key__ for s in samples],
print( __subflavor__=None,
f"The following exception occurred in the dataloader for sample {key} and is skipped", __subflavors__=samples[0].__subflavors__,
file=sys.stderr, tokens=tokens,
) labels=labels,
traceback.print_exc() imgs=imgs,
num_tiles=num_tiles,
cu_lengths=cu_lengths,
def format_multichoice_question(question, multichoice_options): max_lengths=max_lengths,
"""Format multi-choice question.""" )
options_text = ["{}. {}\n".format(chr(ord('A') + i), option) for i, option in
zip(range(len(multichoice_options)), multichoice_options)] def encode_batch(self, batch: ImageTaskBatchPacked) -> dict:
options_text = "".join(options_text) raw = dataclasses.asdict(batch)
del raw["__subflavors__"]
options_text = f"{options_text}Answer with the option's letter from the given choices directly." return raw
return "{}\n{}".format(question, options_text) def select_samples_to_pack(self, samples: List[ImageTaskSample]) -> List[List[ImageTaskSample]]:
"""Selects which samples will be packed together.
def format_multichoice_answer(idx): NOTE: Energon dataloader calls this method internally if packing is used.
"""Format multi-choice answer.""" Please see https://nvidia.github.io/Megatron-Energon/packing.html
return chr(ord('A') + idx) """
        lengths = [sample.total_len for sample in samples]

        packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length)

        return packed_samples

    @stateless
    def pack_selected_samples(self, samples: List[ImageTaskSample]) -> List[ImageTaskSamplePacked]:
        """
        Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.

        NOTE: Energon dataloader calls this method internally if packing is used.
        Please see https://nvidia.github.io/Megatron-Energon/packing.html

        Args:
            samples: List of ImageTaskSample instances to pack into one sample.

        Returns:
            ImageTaskSamplePacked instance.
        """
        packing_seq_len = self.packing_seq_length

        packed_tokens = []
        packed_labels = []
        packed_imgs = []

        current_length = 0
        max_length = 0
        cu_lengths = [0]

        # Process each sample and build lists that we will concatenate to create the packed sample.
        for _, sample in enumerate(samples):
            sample_len = sample.total_len

            if sample_len > max_length:
                max_length = sample_len

            # If adding this sample exceeds the max length, stop.
            # This should not happen. The select_samples_to_pack method should have already ensured that the samples fit.
            if current_length + sample_len > packing_seq_len:
                raise ValueError(f"Packed sample exceeds the maximum sequence length of {packing_seq_len}: {samples}")

            # Add the sample's tokens and labels
            packed_tokens.append(sample.tokens)
            packed_labels.append(sample.labels)

            # Add the images
            packed_imgs += sample.imgs

            current_length += sample_len
            cu_lengths.append(current_length)

        # Concatenate packed tokens and labels.
        packed_tokens = torch.cat(packed_tokens, dim=0)
        packed_labels = torch.cat(packed_labels, dim=0)

        return ImageTaskSamplePacked(
            __key__=",".join([s.__key__ for s in samples]),
            __restore_key__=(),  # Will be set by energon based on `samples`
            __subflavor__=None,
            __subflavors__=samples[0].__subflavors__,
            tokens=packed_tokens,
            labels=packed_labels,
            imgs=packed_imgs,
            cu_lengths=torch.tensor(cu_lengths, dtype=torch.int32),
            max_length=max_length,
            num_tiles=[n for s in samples for n in s.num_tiles],
        )
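    # Illustrative note (not part of the original file): with three samples whose
    # total_len values are 100, 200 and 150 and a packing_seq_length of at least 450,
    # pack_selected_samples produces cu_lengths == [0, 100, 300, 450] and
    # max_length == 200; cu_lengths marks each sample's boundary inside the packed
    # sequence so downstream code can tell the packed samples apart.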
def print_error_handler(exc: Exception, key: Optional[str]):
    print(
        f"The following exception occurred in the dataloader for sample {key} and is skipped",
        file=sys.stderr,
    )
    traceback.print_exc()


def format_multichoice_question(question, multichoice_options):
    """Format multi-choice question."""
    options_text = ["{}. {}\n".format(chr(ord('A') + i), option) for i, option in
                    zip(range(len(multichoice_options)), multichoice_options)]
    options_text = "".join(options_text)

    options_text = f"{options_text}Answer with the option's letter from the given choices directly."

    return "{}\n{}".format(question, options_text)


def format_multichoice_answer(idx):
    """Format multi-choice answer."""
    return chr(ord('A') + idx)
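# A quick worked example of the two helpers above (the question and options are
# made up for illustration only):
#
#   format_multichoice_question("What is shown?", ["a cat", "a dog"])
#   -> "What is shown?\nA. a cat\nB. a dog\nAnswer with the option's letter from the given choices directly."
#
#   format_multichoice_answer(1)
#   -> "B"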
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import warnings
from dataclasses import dataclass
from typing import Any, List

from megatron.energon import Sample
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@dataclass
class SampleListSample(Sample):
    """Sample type for a list of samples of any type which needs to be packed together.

    This is useful for datasets which are packed offline.
    """

    #: The list of samples to pack together
    samples: List[Any]


class SampleListWebdataset(DefaultDecoderWebdatasetFactory[SampleListSample]):
    __sample_type__ = SampleListSample

    def __init__(self, path: EPath, **kwargs):
        warnings.warn(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content",
            DeprecationWarning,
        )
        super().__init__(path, **kwargs)


@dataclass
class OfflineTargetAspectRatioSample(Sample):
    """Sample type for image + text samples with target aspect ratio computed offline."""

    #: The images of the sequence
    images: List[torch.Tensor]
    #: The texts of the sequence
    texts: List[str]
    target_aspect_ratio: List[List]
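# Illustrative note (an assumption based on how encode_llava_sft consumes this type,
# not part of the original file): target_aspect_ratio is expected to hold one tile-grid
# pair per sample, e.g. [[2, 1]]; the task encoder reads tuple(sample.target_aspect_ratio[0])
# and uses it in place of calling find_closest_aspect_ratio_fn.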
import argparse
import json

from evaluate_vqav2 import compute_vqa_accuracy
from evaluate_mmmu import get_input_output_paths


def merge_input_files(input_path):
    """Merge input files to a format compatible with the evaluator."""
    input_file_paths, output_file_path = get_input_output_paths(input_path, task="InfoVQA")

    results = []

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)
                results.append(
                    {
                        "question_id": res["sample_id"],
                        "answer": res["answer"],
                        "gt_answer": res["gt_answer"],
                    }
                )

    # Make order deterministic.
    # results = sorted(results, key=lambda d: d["question_id"])

    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


def infovqa_eval(input_path):
    """Run InfoVQA evaluation."""
    result_file_path = merge_input_files(input_path)
    return compute_vqa_accuracy(result_file_path, task="InfoVQA")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    avg_acc = infovqa_eval(args.input_path)

    print(f"===== InfoVQA Accuracy {avg_acc:.2f}% =====")
import argparse
import json

from evaluate_vqav2 import compute_vqa_accuracy
from evaluate_mmmu import get_input_output_paths


def merge_input_files(input_path):
    """Merge input files to a format compatible with the evaluator."""
    input_file_paths, output_file_path = get_input_output_paths(input_path, task="SPDocVQA")

    results = []

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)
                results.append(
                    {
                        "question_id": res["sample_id"],
                        "answer": res["answer"],
                        "gt_answer": res["gt_answer"],
                    }
                )

    # Make order deterministic.
    # results = sorted(results, key=lambda d: d["question_id"])

    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


def spdocvqa_eval(input_path):
    """Run SPDocVQA evaluation."""
    result_file_path = merge_input_files(input_path)
    return compute_vqa_accuracy(result_file_path, task="SPDocVQA")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    avg_acc = spdocvqa_eval(args.input_path)

    print(f"===== SPDocVQA Accuracy {avg_acc:.2f}% =====")
import argparse
import json
from typing import List

from evaluate_mmmu import get_input_output_paths
from open_flamingo.eval.vqa_metric import VQAEval


# ANLS score calculation based on https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/dist.py#L1
# and https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/score.py#L6
# MIT License. Copyright (c) 2022 Shunsuke KITADA
def levenshtein_distance(s1: str, s2: str) -> int:
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        dists = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                dists.append(distances[i1])
            else:
                dists.append(1 + min((distances[i1], distances[i1 + 1], dists[-1])))
        distances = dists

    return distances[-1]


def normalized_levenshtein_distance(s1: str, s2: str) -> float:
    dist = levenshtein_distance(s1, s2)
    length = max(len(s1.upper()), len(s2.upper()))
    return 0.0 if length == 0 else dist / length


def similarity_function(prediction: str, gold_label: str, threshold: float) -> float:
    nl_score = normalized_levenshtein_distance(prediction, gold_label)
    return 1 - nl_score if nl_score < threshold else 0.0


def anls_score(
    prediction: str, gold_labels: List[str], threshold: float = 0.5
) -> float:
    # not case sensitive, but space sensitive
    y_pred = " ".join(prediction.strip().lower().split())

    anls_scores: List[float] = []
    for gold_label in gold_labels:
        # not case sensitive, but space sensitive
        y_true = " ".join(gold_label.strip().lower().split())

        anls_score = similarity_function(y_pred, y_true, threshold)
        anls_scores.append(anls_score)

    score = max(anls_scores)
    return score


def merge_input_files(input_path):
    """Merge input files to a format compatible with the evaluator."""
    input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2")

    results = dict()

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)
                sample_id = res["sample_id"]

                # Skip possible duplicates.
                if sample_id in results:
                    continue

                res["question_id"] = sample_id
                results[sample_id] = res

    results = list(results.values())

    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


def is_number(n: str):
    """Check if input is a number."""
    try:
        float(n)
        return True
    except ValueError:
        return False


def compute_vqa_accuracy(result_file, task):
    """Compute VQA accuracy."""
    merged_results = json.load(open(result_file))

    vqa = VQAEval(vqa=None, vqaRes=None)
    all_acc = []
    for res in merged_results:
        pred = res["answer"]
        pred = vqa.processPunctuation(pred)
        pred = vqa.processDigitArticle(pred)

        gt = res["gt_answer"]
        gt = [vqa.processPunctuation(ans) for ans in gt]
        gt = [vqa.processDigitArticle(ans) for ans in gt]

        # ChartQA uses relaxed accuracy:
        # "We consider an answer to be correct if it is within 5% of the gold answer.
        # For non-numeric answers, we still need an exact match to consider an answer to be correct."
        if task == "ChartQA":
            acc = 0.0
            assert len(gt) == 1, "expected exactly one groundtruth answer."
            gt = gt[0]

            pred = pred.rstrip("%")
            gt = gt.rstrip("%")

            if is_number(pred) and is_number(gt):
                pred = float(pred)
                gt = float(gt)
                if pred >= (gt * 0.95) and pred <= (gt * 1.05):
                    acc = 1.0
            elif pred == gt:
                acc = 1.0

            all_acc.append(acc)
        elif task in ("VQAv2", "TextVQA"):
            num_match = sum([pred == ans for ans in gt])
            acc = min(1.0, num_match / 3.0)
            all_acc.append(acc)
        elif task in ("SPDocVQA", "InfoVQA"):
            acc = anls_score(prediction=pred, gold_labels=gt, threshold=0.5)
            all_acc.append(acc)
        elif task == "AI2D":
            assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}"
            acc = pred == gt[0]
            all_acc.append(acc)
        else:
            raise NotImplementedError(f"unknown task {task}")

    acc_avg = sum(all_acc) / len(all_acc) * 100

    return acc_avg


def vqav2_eval(input_path):
    """Run VQAv2 evaluation."""
    result_file = merge_input_files(input_path)
    avg_acc = compute_vqa_accuracy(result_file, task="VQAv2")
    return avg_acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    avg_acc = vqav2_eval(args.input_path)

    print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Evaluation datasets."""
import glob
import itertools
import json
import os
import re
from collections import defaultdict

import numpy as np
import torch
from image_processing import get_visual_transform
from PIL import Image

from megatron.training import print_rank_0


def _get_partition_bounds(
    total_num_samples, num_samples_per_partition, num_partitions, partition_id
):
    if num_samples_per_partition == 0:
        samples_per_partition = [
            int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1)
        ]
        return samples_per_partition[partition_id], samples_per_partition[partition_id + 1]
    return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1)
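# Illustrative examples of the helper above (sizes are made up):
#   _get_partition_bounds(100, 0, 4, 1)  -> (25, 50)   # even split via np.linspace
#   _get_partition_bounds(100, 10, 4, 1) -> (10, 20)   # fixed number of samples per partition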
class VQADataset(torch.utils.data.Dataset):
    """VQA evaluation dataset."""

    def __init__(
        self,
        input_image_path,
        gt_path,
        num_samples_per_partition,
        num_partitions,
        partition_id,
        keys,
        img_h,
        img_w,
        use_tiling,
        max_num_tiles,
        use_thumbnail,
        vision_model_type,
    ):
        samples = json.load(open(gt_path, encoding='utf-8'))
        if "data" in samples:
            samples = samples["data"]

        # Optionally, process only a subset of the input files.
        if num_partitions > 0:
            lb, ub = _get_partition_bounds(
                len(samples), num_samples_per_partition, num_partitions, partition_id
            )
            samples = samples[lb:ub]

        self._keys = keys
        self._samples = samples
        self._input_image_path = input_image_path
        self._img_h = img_h
        self._img_w = img_w
        self._use_tiling = use_tiling
        self._max_num_tiles = max_num_tiles
        self._use_thumbnail = use_thumbnail
        self._vision_model_type = vision_model_type

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        sample = self._samples[idx]

        img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]])
        if not os.path.exists(img_file):
            img_file += ".jpg"

            if not os.path.exists(img_file):
                img_file = img_file.replace('.jpg', '.png')

        img = Image.open(img_file)
        imgs = get_visual_transform(
            img,
            self._img_h,
            self._img_w,
            self._use_tiling,
            self._max_num_tiles,
            self._use_thumbnail,
            augment=False,
            vision_model_type=self._vision_model_type,
        )

        tile_count = torch.tensor([len(imgs)], dtype=torch.int)

        sample_id = idx
        if "sample_id" in self._keys:
            sample_id = sample[self._keys["sample_id"]]

        metadata = ""  # Not used.

        return (
            torch.stack(imgs),
            tile_count,
            sample_id,
            sample[self._keys["question"]],
            sample[self._keys["answer"]],
            metadata,
        )
class CaptioningDataset(torch.utils.data.Dataset): class CaptioningDataset(torch.utils.data.Dataset):
"""Captioning evaluation dataset.""" """Captioning evaluation dataset."""
def __init__( def __init__(
self, self,
input_image_path, input_image_path,
gt_path, gt_path,
num_samples_per_partition, num_samples_per_partition,
num_partitions, num_partitions,
partition_id, partition_id,
img_h, img_h,
img_w, img_w,
use_tiling, use_tiling,
max_num_tiles, max_num_tiles,
use_thumbnail, use_thumbnail,
vision_model_type, vision_model_type,
): ):
image_files = sorted(glob.glob(input_image_path + "/*")) image_files = sorted(glob.glob(input_image_path + "/*"))
# Optionally, process only a subset of the input files. # Optionally, process only a subset of the input files.
if num_partitions > 0: if num_partitions > 0:
lb, ub = _get_partition_bounds( lb, ub = _get_partition_bounds(
len(image_files), num_samples_per_partition, num_partitions, partition_id len(image_files), num_samples_per_partition, num_partitions, partition_id
) )
image_files = image_files[lb:ub] image_files = image_files[lb:ub]
gts = json.load(open(gt_path)) gts = json.load(open(gt_path))
answers = defaultdict(list) answers = defaultdict(list)
for gt in gts["annotations"]: for gt in gts["annotations"]:
answers[gt["image_id"]].append(gt['caption']) answers[gt["image_id"]].append(gt['caption'])
self._image_files = image_files self._image_files = image_files
self._answers = answers self._answers = answers
self._img_h = img_h self._img_h = img_h
self._img_w = img_w self._img_w = img_w
self._use_tiling = use_tiling self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail self._use_thumbnail = use_thumbnail
self._vision_model_type = vision_model_type self._vision_model_type = vision_model_type
def __len__(self): def __len__(self):
return len(self._image_files) return len(self._image_files)
def __getitem__(self, idx): def __getitem__(self, idx):
img_file = self._image_files[idx] img_file = self._image_files[idx]
image_id = int(img_file.split("_")[-1].split(".")[0]) image_id = int(img_file.split("_")[-1].split(".")[0])
img = Image.open(img_file) img = Image.open(img_file)
imgs = get_visual_transform( imgs = get_visual_transform(
img, img,
self._img_h, self._img_h,
self._img_w, self._img_w,
self._use_tiling, self._use_tiling,
self._max_num_tiles, self._max_num_tiles,
self._use_thumbnail, self._use_thumbnail,
augment=False, augment=False,
vision_model_type=self._vision_model_type, vision_model_type=self._vision_model_type,
) )
tile_count = torch.tensor([len(imgs)], dtype=torch.int) tile_count = torch.tensor([len(imgs)], dtype=torch.int)
question = "" # Fixed for all samples. question = "" # Fixed for all samples.
metadata = "" # Not used. metadata = "" # Not used.
return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata
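# Illustrative sketch (assumes COCO-style file names such as
# "COCO_val2014_000000391895.jpg", which is how CaptioningDataset recovers the image id
# from the last underscore-separated token of the file name). The helper below is
# hypothetical and only demonstrates the parsing above.
def _example_captioning_image_id():
    img_file = "images/COCO_val2014_000000391895.jpg"  # hypothetical path
    image_id = int(img_file.split("_")[-1].split(".")[0])
    assert image_id == 391895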
class MMMUDataset(torch.utils.data.Dataset): class MMMUDataset(torch.utils.data.Dataset):
"""MMMU evaluation dataset.""" """MMMU evaluation dataset."""
def __init__( def __init__(
self, self,
input_image_path, input_image_path,
num_samples_per_partition, num_samples_per_partition,
num_partitions, num_partitions,
partition_id, partition_id,
img_h, img_h,
img_w, img_w,
use_tiling, use_tiling,
max_num_tiles, max_num_tiles,
use_thumbnail, use_thumbnail,
prompt_style, prompt_style,
vision_model_type, vision_model_type,
): ):
import datasets import datasets
from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml
# The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation.
all_mmmu_datasets = [] all_mmmu_datasets = []
hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] hf_datasets_cache = os.environ["HF_DATASETS_CACHE"]
assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE."
for subject in CAT_SHORT2LONG.values(): for subject in CAT_SHORT2LONG.values():
# Use a local copy of the dataset if exists (can be faster) or the HF one. # Use a local copy of the dataset if exists (can be faster) or the HF one.
if os.path.exists(input_image_path): if os.path.exists(input_image_path):
subject_dataset = datasets.load_dataset( subject_dataset = datasets.load_dataset(
os.path.join(input_image_path, subject), os.path.join(input_image_path, subject),
split=datasets.Split.VALIDATION, split=datasets.Split.VALIDATION,
cache_dir=hf_datasets_cache, cache_dir=hf_datasets_cache,
verification_mode="no_checks", verification_mode="no_checks",
) )
else: else:
subject_dataset = datasets.load_dataset( subject_dataset = datasets.load_dataset(
"MMMU/MMMU", "MMMU/MMMU",
subject, subject,
split=datasets.Split.VALIDATION, split=datasets.Split.VALIDATION,
cache_dir=hf_datasets_cache, cache_dir=hf_datasets_cache,
) )
all_mmmu_datasets.append(subject_dataset) all_mmmu_datasets.append(subject_dataset)
dataset = datasets.concatenate_datasets(all_mmmu_datasets) dataset = datasets.concatenate_datasets(all_mmmu_datasets)
dataset = [s for s in dataset if s['id'].startswith("val")] dataset = [s for s in dataset if s['id'].startswith("val")]
# Optionally, process only a subset of the input files. # Optionally, process only a subset of the input files.
if num_partitions > 0: if num_partitions > 0:
lb, ub = _get_partition_bounds( lb, ub = _get_partition_bounds(
len(dataset), num_samples_per_partition, num_partitions, partition_id len(dataset), num_samples_per_partition, num_partitions, partition_id
) )
dataset = dataset[lb:ub] dataset = dataset[lb:ub]
# Using the LLaVA config from the MMMU repo. # Using the LLaVA config from the MMMU repo.
config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml")
for k, v in config.items(): for k, v in config.items():
if isinstance(v, list): if isinstance(v, list):
assert len(v) == 1, "only one value supported." assert len(v) == 1, "only one value supported."
config[k] = v[0] config[k] = v[0]
self._config = config self._config = config
self._dataset = dataset self._dataset = dataset
self._img_h = img_h self._img_h = img_h
self._img_w = img_w self._img_w = img_w
self._use_tiling = use_tiling self._use_tiling = use_tiling
self._max_num_tiles = max_num_tiles self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail self._use_thumbnail = use_thumbnail
self._prompt_style = prompt_style self._prompt_style = prompt_style
self._vision_model_type = vision_model_type self._vision_model_type = vision_model_type
def __len__(self): def __len__(self):
return len(self._dataset) return len(self._dataset)
def __getitem__(self, idx): def __getitem__(self, idx):
from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample
sample = self._dataset[idx] sample = self._dataset[idx]
# Use the single image approach from the MMMU repo. # Use the single image approach from the MMMU repo.
if self._prompt_style == "single_image": if self._prompt_style == "single_image":
sample = process_single_sample(sample) sample = process_single_sample(sample)
sample = construct_prompt(sample, self._config) sample = construct_prompt(sample, self._config)
img = sample["image"] img = sample["image"]
sample_imgs = get_visual_transform( sample_imgs = get_visual_transform(
img, img,
self._img_h, self._img_h,
self._img_w, self._img_w,
self._use_tiling, self._use_tiling,
self._max_num_tiles, self._max_num_tiles,
self._use_thumbnail, self._use_thumbnail,
augment=False, augment=False,
vision_model_type=self._vision_model_type, vision_model_type=self._vision_model_type,
) )
sample_num_tiles = [len(sample_imgs)] sample_num_tiles = [len(sample_imgs)]
prompt = sample["final_input_prompt"] prompt = sample["final_input_prompt"]
for i in range(8): for i in range(8):
prompt = prompt.replace(f"<image {i}>", "") prompt = prompt.replace(f"<image {i}>", "")
sample["final_input_prompt"] = f"<image>\n{prompt}" sample["final_input_prompt"] = f"<image>\n{prompt}"
elif self._prompt_style == "vlmevalkit": elif self._prompt_style == "vlmevalkit":
sample = construct_prompt(sample, self._config) sample = construct_prompt(sample, self._config)
if sample["question_type"] == "multiple-choice": if sample["question_type"] == "multiple-choice":
question = sample["question"] question = sample["question"]
options = "" options = ""
for k, v in sample["index2ans"].items(): for k, v in sample["index2ans"].items():
options += f"{k}. {v}\n" options += f"{k}. {v}\n"
final_prompt = f"{question}\n" final_prompt = f"{question}\n"
if "hint" in sample: if "hint" in sample:
final_prompt += f"Hint: {sample['hint']}\n" final_prompt += f"Hint: {sample['hint']}\n"
if "task_instructions" in sample: if "task_instructions" in sample:
final_prompt += f"Task instructions: {sample['task_instructions']}\n" final_prompt += f"Task instructions: {sample['task_instructions']}\n"
final_prompt += options final_prompt += options
final_prompt += "Answer with the option's letter from the given choices directly." final_prompt += "Answer with the option's letter from the given choices directly."
sample["final_input_prompt"] = final_prompt.rstrip() sample["final_input_prompt"] = final_prompt.rstrip()
else: else:
question = sample["question"] question = sample["question"]
final_prompt = f"{question}\n" final_prompt = f"{question}\n"
final_prompt += "Answer the question directly." final_prompt += "Answer the question directly."
sample["final_input_prompt"] = final_prompt.rstrip() sample["final_input_prompt"] = final_prompt.rstrip()
sample_imgs = [] sample_imgs = []
sample_num_tiles = [] sample_num_tiles = []
img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"])))) img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"]))))
# If there are multiple input images, we need to avoid the number of image embeddings getting too large. # If there are multiple input images, we need to avoid the number of image embeddings getting too large.
adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices)) adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
adjusted_max_num_tiles = min(adjusted_max_num_tiles, self._max_num_tiles) adjusted_max_num_tiles = min(adjusted_max_num_tiles, self._max_num_tiles)
for img_idx in img_indices: for img_idx in img_indices:
img_key = f"image_{img_idx}" img_key = f"image_{img_idx}"
img_str = f"<image {img_idx}>" img_str = f"<image {img_idx}>"
img = sample[img_key] img = sample[img_key]
assert img is not None, f"{img_str} is in prompt but not in sample images" assert img is not None, f"{img_str} is in prompt but not in sample images"
imgs = get_visual_transform( imgs = get_visual_transform(
img, img,
self._img_h, self._img_h,
self._img_w, self._img_w,
self._use_tiling, self._use_tiling,
adjusted_max_num_tiles, adjusted_max_num_tiles,
self._use_thumbnail, self._use_thumbnail,
augment=False, augment=False,
vision_model_type=self._vision_model_type, vision_model_type=self._vision_model_type,
) # List of tiles. ) # List of tiles.
sample_imgs.extend(imgs) sample_imgs.extend(imgs)
sample_num_tiles.append(len(imgs)) sample_num_tiles.append(len(imgs))
sample["final_input_prompt"] = " ".join([f'<image {i + 1}><image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"] sample["final_input_prompt"] = " ".join([f'<image {i + 1}><image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"]
elif self._prompt_style == "multi_image": elif self._prompt_style == "multi_image":
sample = construct_prompt(sample, self._config) sample = construct_prompt(sample, self._config)
sample_imgs = [] sample_imgs = []
sample_num_tiles = [] sample_num_tiles = []
img_indices = re.findall(r"<image (\d+)", sample["final_input_prompt"]) img_indices = re.findall(r"<image (\d+)", sample["final_input_prompt"])
# If there are multiple input images, we need to avoid the number of image embeddings getting too large. # If there are multiple input images, we need to avoid the number of image embeddings getting too large.
adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices)) adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
for img_idx in img_indices: for img_idx in img_indices:
img_key = f"image_{img_idx}" img_key = f"image_{img_idx}"
img_str = f"<image {img_idx}>" img_str = f"<image {img_idx}>"
img = sample[img_key] img = sample[img_key]
assert img is not None, f"{img_str} is in prompt but not in sample images" assert img is not None, f"{img_str} is in prompt but not in sample images"
# Note: Only replace the current image tag. # Note: Only replace the current image tag.
sample["final_input_prompt"] = sample["final_input_prompt"].replace( sample["final_input_prompt"] = sample["final_input_prompt"].replace(
img_str, "<image>", 1 img_str, "<image>", 1
) )
imgs = get_visual_transform( imgs = get_visual_transform(
img, img,
self._img_h, self._img_h,
self._img_w, self._img_w,
self._use_tiling, self._use_tiling,
adjusted_max_num_tiles, adjusted_max_num_tiles,
self._use_thumbnail, self._use_thumbnail,
augment=False, augment=False,
vision_model_type=self._vision_model_type, vision_model_type=self._vision_model_type,
) # List of tiles. ) # List of tiles.
sample_imgs.extend(imgs) sample_imgs.extend(imgs)
sample_num_tiles.append(len(imgs)) sample_num_tiles.append(len(imgs))
# Sanity check. # Sanity check.
for i in range(1, 8): for i in range(1, 8):
assert ( assert (
f"<image {i}>" not in sample["final_input_prompt"] f"<image {i}>" not in sample["final_input_prompt"]
), "prompt contains unhandled image tags" ), "prompt contains unhandled image tags"
else: else:
raise ValueError(f"unknown prompt style {self._prompt_style}") raise ValueError(f"unknown prompt style {self._prompt_style}")
# MMMU specific metadata. # MMMU specific metadata.
metadata = {"question_type": sample["question_type"]} metadata = {"question_type": sample["question_type"]}
if sample["question_type"] == "multiple-choice": if sample["question_type"] == "multiple-choice":
metadata["index2ans"] = sample["index2ans"] metadata["index2ans"] = sample["index2ans"]
metadata["all_choices"] = sample["all_choices"] metadata["all_choices"] = sample["all_choices"]
prompt = sample['final_input_prompt'] prompt = sample['final_input_prompt']
tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) tile_count = torch.tensor(sample_num_tiles, dtype=torch.int)
return ( return (
torch.stack(sample_imgs), torch.stack(sample_imgs),
tile_count, tile_count,
sample["id"], sample["id"],
prompt, prompt,
sample["answer"], sample["answer"],
metadata, metadata,
) )
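# Illustrative sketch (added for clarity, not part of the original file): how the
# per-image tile budget shrinks in the multi-image prompt styles above so that the total
# number of image embeddings stays bounded. Values are hypothetical.
def _example_adjusted_max_num_tiles(max_num_tiles=12, num_images=5):
    # e.g. a 12-tile budget shared across 5 referenced images -> at most 2 tiles per
    # image, never fewer than 1.
    return max(1, max_num_tiles // num_images)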
class VideoMMMEDataset(torch.utils.data.Dataset): class VideoMMEDataset(torch.utils.data.Dataset):
"Video MME evaluation dataset." "Video MME evaluation dataset."
def __init__( def __init__(
self, self,
input_image_path, input_image_path,
gt_path, gt_path,
num_samples_per_partition, num_samples_per_partition,
num_partitions, num_partitions,
partition_id, partition_id,
img_h, img_h,
img_w, img_w,
use_tiling, use_tiling,
max_num_tiles, max_num_tiles,
use_thumbnail, use_thumbnail,
num_frames, num_frames,
vision_model_type, vision_model_type,
): ):
ground_truth_original = json.load(open(gt_path)) ground_truth_original = json.load(open(gt_path))
ground_truth = [] ground_truth = []
for gt in ground_truth_original: for gt in ground_truth_original:
video_path = gt["url"] video_path = gt["url"]
video_path = video_path.replace("https://www.youtube.com/watch?v=", "") video_path = video_path.replace("https://www.youtube.com/watch?v=", "")
video_path = video_path.replace("https://m.youtube.com/watch?v=", "") video_path = video_path.replace("https://m.youtube.com/watch?v=", "")
video_path = os.path.join(input_image_path, video_path + ".mp4") video_path = os.path.join(input_image_path, video_path + ".mp4")
if not os.path.exists(video_path): if not os.path.exists(video_path):
continue continue
gt["video_path"] = video_path gt["video_path"] = video_path
ground_truth.append(gt) ground_truth.append(gt)
ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"])
print_rank_0(f"Found {len(ground_truth)} videos to process.") print_rank_0(f"Found {len(ground_truth)} videos to process.")
if num_partitions > 0: if num_partitions > 0:
start_idx, end_idx = _get_partition_bounds( start_idx, end_idx = _get_partition_bounds(
len(ground_truth), num_samples_per_partition, num_partitions, partition_id len(ground_truth), num_samples_per_partition, num_partitions, partition_id
) )
ground_truth = ground_truth[start_idx:end_idx] ground_truth = ground_truth[start_idx:end_idx]
self._ground_truth = ground_truth self._ground_truth = ground_truth
self._img_h = img_h self._img_h = img_h
self._img_w = img_w self._img_w = img_w
self._use_tiling = use_tiling self._use_tiling = False
self._max_num_tiles = max_num_tiles self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail self._use_thumbnail = use_thumbnail
self._num_frames = num_frames self._num_frames = num_frames
self._vision_model_type = vision_model_type self._vision_model_type = vision_model_type
def __len__(self): def __len__(self):
return len(self._ground_truth) return len(self._ground_truth)
def __getitem__(self, idx): def __getitem__(self, idx):
from torchvision.io import read_video from torchvision.io import read_video
gt = self._ground_truth[idx] gt = self._ground_truth[idx]
video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec')
video = video.numpy() video = video.numpy()
selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long()
video_frames = video[selected_frames] video_frames = video[selected_frames]
if self._num_frames == 1: if self._num_frames == 1:
video_frames = video_frames[None] video_frames = video_frames[None]
        # Convert numpy frames to PIL images before applying the visual transform.
        from torchvision.transforms import ToPILImage
        to_pil = ToPILImage()

        imgs = []
        for img in video_frames:
            img = to_pil(img)
            imgs += get_visual_transform(
                img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles,
                self._use_thumbnail, augment=False, vision_model_type=self._vision_model_type
            )

        for question in gt["questions"]:
            # Very hacky, but we essentially re-create gt holding only the
            # question of interest. This is to make this generation script
            # compatible with the Video MME evaluation script.
            question_dict = {
                "video_id": gt["video_id"],
                "duration_category": gt["duration_category"],
                "video_category": gt["video_category"],
                "video_subcategory": gt["video_subcategory"],
                "url": gt["url"],
                "questions": [question],
            }

        num_tiles = torch.tensor([len(imgs)], dtype=torch.int)

        answer = ""
        metadata = ""

        return (
            torch.stack(imgs),
            num_tiles,
            question["question_id"],
            question_dict,
            answer,
            metadata,
        )


class OCRBenchDataset(torch.utils.data.Dataset):
    """OCRBench evaluation dataset."""

    def __init__(
        self,
        input_image_path,
        gt_path,
        num_samples_per_partition,
        num_partitions,
        partition_id,
        img_h,
        img_w,
        use_tiling,
        max_num_tiles,
        use_thumbnail,
        vision_model_type,
    ):
        gt = json.load(open(gt_path, encoding='utf-8'))

        if num_partitions > 0:
            start_idx, end_idx = _get_partition_bounds(
                len(gt), num_samples_per_partition, num_partitions, partition_id
            )
            gt = gt[start_idx:end_idx]

        self._input_image_path = input_image_path
        self._gt = gt
        self._img_h = img_h
        self._img_w = img_w
        self._use_tiling = use_tiling
        self._max_num_tiles = max_num_tiles
        self._use_thumbnail = use_thumbnail
        self._vision_model_type = vision_model_type

    def __len__(self):
        return len(self._gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path'])

        img = Image.open(img_path)
        imgs = get_visual_transform(
            img,
            self._img_h,
            self._img_w,
            self._use_tiling,
            self._max_num_tiles,
            self._use_thumbnail,
            augment=False,
            vision_model_type=self._vision_model_type,
        )

        tile_count = torch.tensor([len(imgs)], dtype=torch.int)

        metadata = {
            "dataset_name": self._gt[idx]["dataset_name"],
            "data_type": self._gt[idx]["type"],
        }

        return (
            torch.stack(imgs),
            tile_count,
            idx,
            self._gt[idx]["question"],
            self._gt[idx]["answers"],
            metadata,
        )


class MathVistaDataset(torch.utils.data.Dataset):
    """MathVista evaluation dataset."""

    def __init__(
        self,
        input_image_path,
        num_samples_per_partition,
        num_partitions,
        partition_id,
        img_h,
        img_w,
        use_tiling,
        max_num_tiles,
        use_thumbnail,
        vision_model_type,
    ):
        import datasets

        hf_datasets_cache = os.environ["HF_DATASETS_CACHE"]
        assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE."

        if os.path.exists(input_image_path):
            dataset = datasets.load_dataset(
                input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks"
            )
        else:
            dataset = datasets.load_dataset(
                "AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache
            )

        if num_partitions > 0:
            start_idx, end_idx = _get_partition_bounds(
                len(dataset), num_samples_per_partition, num_partitions, partition_id
            )
            dataset = dataset[start_idx:end_idx]

        self._dataset = dataset
        self._img_h = img_h
        self._img_w = img_w
        self._use_tiling = use_tiling
        self._max_num_tiles = max_num_tiles
        self._use_thumbnail = use_thumbnail
        self._vision_model_type = vision_model_type

    def __len__(self):
        return len(self._dataset["pid"])

    def __getitem__(self, idx):
        # Already a PIL object.
        img = self._dataset['decoded_image'][idx]

        imgs = get_visual_transform(
            img,
            self._img_h,
            self._img_w,
            self._use_tiling,
            self._max_num_tiles,
            self._use_thumbnail,
            augment=False,
            vision_model_type=self._vision_model_type,
        )

        tile_count = torch.tensor([len(imgs)], dtype=torch.int)

        question_id = self._dataset["pid"][idx]
        question = self._dataset["question"][idx]
        question_type = self._dataset["question_type"][idx]  # free_form or multi_choice
        query = self._dataset["query"][idx]
        choices = self._dataset["choices"][idx]
        answer = self._dataset["answer"][idx]

        if question_type == 'multi_choice':
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = query.replace("Hint: ", "")
            index2ans = {}
            all_choices = []

        metadata = {
            "question_type": question_type,
            "index2ans": index2ans,
            "all_choices": all_choices,
        }

        return torch.stack(imgs), tile_count, question_id, question, answer, metadata


class AI2DDataset(torch.utils.data.Dataset):
    """AI2D evaluation dataset."""

    def __init__(
        self,
        input_image_path,
        gt_path,
        num_samples_per_partition,
        num_partitions,
        partition_id,
        img_h,
        img_w,
        use_tiling,
        max_num_tiles,
        use_thumbnail,
        no_mask,
        vision_model_type,
    ):
        with open(gt_path, 'r') as f:
            jsonl = list(f)

        gt = [json.loads(json_str) for json_str in jsonl]

        if num_partitions > 0:
            start_idx, end_idx = _get_partition_bounds(
                len(gt), num_samples_per_partition, num_partitions, partition_id
            )
            gt = gt[start_idx:end_idx]

        self._gt = gt
        self._input_image_path = input_image_path
        self._img_h = img_h
        self._img_w = img_w
        self._use_tiling = use_tiling
        self._max_num_tiles = max_num_tiles
        self._use_thumbnail = use_thumbnail
        self._no_mask = no_mask
        self._vision_model_type = vision_model_type

    def __len__(self):
        return len(self._gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self._input_image_path, self._gt[idx]['image'])
        if self._no_mask:
            # str.replace returns a new string; assign it back so the unmasked image path is used.
            img_path = img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")

        img = Image.open(img_path)
        imgs = get_visual_transform(
            img,
            self._img_h,
            self._img_w,
            self._use_tiling,
            self._max_num_tiles,
            self._use_thumbnail,
            augment=False,
            vision_model_type=self._vision_model_type,
        )

        tile_count = torch.tensor([len(imgs)], dtype=torch.int)

        metadata = ""  # Not used.

        return (
            torch.stack(imgs),
            tile_count,
            self._gt[idx]["question_id"],
            self._gt[idx]["question"],
            self._gt[idx]["answer"],
            metadata,
        )


def get_evaluation_dataset(
    task,
    input_image_path,
    gt_path,
    img_h,
    img_w,
    use_tiling,
    max_num_tiles,
    use_thumbnail,
    num_samples_per_partition,
    num_partitions,
    partition_id,
    num_frames,
    vision_model_type,
):
    """Get an evaluation dataset."""
    if task == "TextVQA":
        keys = {
            "image_id": "image_id",
            "sample_id": "question_id",
            "question": "question",
            "answer": "answers",
        }

        dataset = VQADataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            keys,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "VQAv2":
        keys = {
            "image_id": "image",
            "sample_id": "question_id",
            "question": "question",
            "answer": "answer",
        }

        dataset = VQADataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            keys,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "ChartQA":
        keys = {"image_id": "imgname", "question": "query", "answer": "label"}

        dataset = VQADataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            keys,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "captioning":
        dataset = CaptioningDataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == 'MMMU':
        # Note:
        # - prompt_style="single_image" uses only one image like in the MMMU repo example.
        # - prompt_style="multi_image" uses multiple input images.
        # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499
        dataset = MMMUDataset(
            input_image_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            prompt_style="single_image",
            vision_model_type=vision_model_type,
        )
    elif task == "VideoMME":
        dataset = VideoMMEDataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            num_frames,
            vision_model_type,
        )
    elif task == "OCRBench":
        dataset = OCRBenchDataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "MathVista":
        dataset = MathVistaDataset(
            input_image_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "AI2D":
        dataset = AI2DDataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            no_mask=False,
            vision_model_type=vision_model_type,
        )
    elif task == "SPDocVQA":
        keys = {"sample_id": "questionId", "image_id": "image", "question": "question", "answer": "answers"}

        dataset = VQADataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            keys,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    elif task == "InfoVQA":
        keys = {"sample_id": "questionId", "image_id": "image_local_name", "question": "question", "answer": "answers"}

        dataset = VQADataset(
            input_image_path,
            gt_path,
            num_samples_per_partition,
            num_partitions,
            partition_id,
            keys,
            img_h,
            img_w,
            use_tiling,
            max_num_tiles,
            use_thumbnail,
            vision_model_type,
        )
    else:
        raise NotImplementedError(f"unsupported task {task}")

    return dataset
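# Illustrative usage sketch (paths and sizes below are hypothetical, not from the
# original file): build one evaluation dataset and iterate it with a plain PyTorch
# DataLoader, which is roughly how the generation scripts consume these datasets.
def _example_build_eval_loader():
    dataset = get_evaluation_dataset(
        task="TextVQA",
        input_image_path="/data/textvqa/train_images",   # hypothetical
        gt_path="/data/textvqa/val_annotations.json",     # hypothetical
        img_h=448,
        img_w=448,
        use_tiling=False,
        max_num_tiles=1,
        use_thumbnail=False,
        num_samples_per_partition=0,
        num_partitions=1,
        partition_id=0,
        num_frames=1,
        vision_model_type="clip",
    )
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)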
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
from torchvision import transforms as T from torchvision import transforms as T
from torchvision.transforms import Compose from torchvision.transforms import Compose
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406]
IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225]
SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5]
SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5]
CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]
pixel_statistics = {
    "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
    "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
    "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
    "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}


# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """
    Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
    """
    best_factor = float('-inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        factor_based_on_area_n_ratio = (
            min((ratio[0] * ratio[1] * image_size * image_size) / area, 0.6)
            * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
        )
        if factor_based_on_area_n_ratio > best_factor:
            best_factor = factor_based_on_area_n_ratio
            best_ratio = ratio
    return best_ratio


def get_visual_transform(
    img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False,
    vision_model_type="clip", find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
):
    pixel_mean, pixel_std = pixel_statistics[vision_model_type]

    assert not augment, "Image augmentation not implemented."
    transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type)

    if use_tiling:
        assert img_h == img_w, "dynamic tiling expects equal tile height and width"
        imgs = dynamic_preprocess(
            img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
            find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn,
        )
        imgs = [transform(img) for img in imgs]
    else:
        imgs = [transform(img)]

    return imgs


# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
# Copyright (c) 2023 OpenGVLab.
def dynamic_preprocess(
    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
    find_closest_aspect_ratio_fn=find_closest_aspect_ratio,
):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio_fn(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
def build_transform(input_size, pixel_mean, pixel_std, vision_model_type):
    if vision_model_type in ("siglip", "internvit", "radio", "huggingface"):
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=pixel_mean, std=pixel_std),
        ])
    elif vision_model_type == "clip":
        transform = Compose([
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Normalize(mean=pixel_mean, std=pixel_std),
        ])
    else:
        raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")

    return transform
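# Illustrative sketch (sizes are hypothetical, not from the original file): for a wide
# 896x448 input with 448px tiles, find_closest_aspect_ratio picks a 2x1 grid, so
# dynamic_preprocess returns 2 tiles plus an optional 448x448 thumbnail of the full image.
def _example_dynamic_tiling():
    from PIL import Image as PILImage
    img = PILImage.new("RGB", (896, 448))
    tiles = dynamic_preprocess(img, min_num=1, max_num=6, image_size=448, use_thumbnail=True)
    assert len(tiles) == 3  # 2 tiles + 1 thumbnail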
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch import torch
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try: try:
from megatron.core.extensions.transformer_engine import ( from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear, TEColumnParallelLinear,
TEDotProductAttention, TEDotProductAttention,
TELayerNormColumnParallelLinear, TELayerNormColumnParallelLinear,
TENorm, TENorm,
TERowParallelLinear, TERowParallelLinear,
) )
HAVE_TE = True HAVE_TE = True
except ImportError: except ImportError:
HAVE_TE = False HAVE_TE = False
try: try:
import apex import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.torch_norm import WrappedTorchNorm from megatron.core.transformer.torch_norm import WrappedTorchNorm
HAVE_APEX = True HAVE_APEX = True
LNImpl = FusedLayerNorm LNImpl = FusedLayerNorm
except ImportError: except ImportError:
import warnings import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm') warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm LNImpl = WrappedTorchNorm
def get_layer_spec(is_vit, normalization) -> ModuleSpec: def get_layer_spec(is_vit, normalization) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
if normalization == "LayerNorm": if normalization == "LayerNorm":
norm = LNImpl norm = LNImpl
elif normalization == "RMSNorm": elif normalization == "RMSNorm":
if HAVE_TE: if HAVE_TE:
norm = TENorm norm = TENorm
else: else:
                version = torch.__version__.split('.')
                version_geq_2_4 = (
                    int(version[0]) > 2
                    or (int(version[0]) == 2 and int(version[1]) >= 4)
                )
assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm" assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
if HAVE_APEX: if HAVE_APEX:
warnings.warn(f'Apex does not support RMSNorm. Falling back to Torch Norm') warnings.warn(f'Apex does not support RMSNorm. Falling back to Torch Norm')
norm = WrappedTorchNorm norm = WrappedTorchNorm
else: else:
raise RuntimeError("unknown normalization", normalization) raise RuntimeError("unknown normalization", normalization)
mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. mlp = get_mlp_module_spec(use_te=False) # doesn't include norm.
return ModuleSpec( return ModuleSpec(
module=TransformerLayer, module=TransformerLayer,
submodules=TransformerLayerSubmodules( submodules=TransformerLayerSubmodules(
input_layernorm=norm, input_layernorm=norm,
self_attention=ModuleSpec( self_attention=ModuleSpec(
module=SelfAttention, module=SelfAttention,
params={"attn_mask_type": attn_mask_type}, params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules( submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear, linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention, core_attention=DotProductAttention,
linear_proj=RowParallelLinear, linear_proj=RowParallelLinear,
q_layernorm=IdentityOp, q_layernorm=IdentityOp,
k_layernorm=IdentityOp, k_layernorm=IdentityOp,
), ),
), ),
self_attn_bda=get_bias_dropout_add, self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=norm, pre_mlp_layernorm=norm,
mlp=mlp, mlp=mlp,
mlp_bda=get_bias_dropout_add, mlp_bda=get_bias_dropout_add,
), ),
) )
def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
    # Padding mask is needed for e.g. Context Parallel.
    if padding:
        assert not is_vit, "padding_causal mask not used with ViT"
        attn_mask_type = AttnMaskType.padding_causal

    mlp = get_norm_mlp_module_spec_te()
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )


def get_norm_mlp_module_spec_te() -> ModuleSpec:
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
        ),
    )
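# Illustrative sketch (added for clarity, not part of the original file): how a caller
# might choose between the Transformer Engine spec and the plain spec for the language
# model, requesting the padding_causal mask only when it is needed (e.g. for Context
# Parallel). The function name and flags below are hypothetical.
def _example_pick_language_layer_spec(use_te: bool, use_padding_mask: bool) -> ModuleSpec:
    if use_te:
        return get_layer_spec_te(is_vit=False, padding=use_padding_mask)
    return get_layer_spec(is_vit=False, normalization="LayerNorm")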
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from copy import deepcopy

import torch

from config import get_language_model_config, get_vision_model_config, get_vision_projection_config
from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te

from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args


def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
) -> LLaVAModel:
    """Builds the model.

    Args:
        pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True.
        post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True.
        add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder
            will live on only a subset of the pipeline stages (specifically, only the first stage).
        add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder
            will live on only a subset of the pipeline stages (specifically, every stage after the first one).
        parallel_output (bool): Enable parallel model output.

    Returns:
        model: A multimodal model.
    """
    args = get_args()

    assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for the encoder on its own pipeline rank"

    use_te = args.use_te

    print_rank_0('building a multimodal model ...')

    num_image_embeddings = get_num_image_embeddings(
        args.img_h,
        args.img_w,
        args.patch_dim,
        args.vision_model_type,
        args.disable_vision_class_token,
        1,
        args.pixel_shuffle,
        args.use_tile_tags,
    )
    old_seq_length = args.seq_length
    args.seq_length = args.encoder_seq_length = num_image_embeddings
    if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length:
        warnings.warn(
            f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
        )

    max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings

    assert (
        args.decoder_seq_length is not None
    ), "Please provide --decoder-seq-length to set the language model sequence length"
    assert (
        args.decoder_seq_length > max_num_image_embeddings
    ), "Language model sequence length must be greater than the maximum number of image embeddings"
    if args.decoder_seq_length > args.max_position_embeddings:
        args.max_position_embeddings = args.decoder_seq_length
        warnings.warn(
            f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length"
        )
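    # Rough illustration of the check above with assumed numbers (not read from args):
    # a 448x448 image with 14-pixel patches and the class token kept gives
    # (448 // 14) ** 2 + 1 = 1025 embeddings per tile; with max_num_tiles = 6 plus a
    # thumbnail, that is (6 + 1) * 1025 = 7175, so --decoder-seq-length must exceed
    # 7175 (pixel shuffle and tile tags change the exact count).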
    base_config = core_transformer_config_from_args(get_args())
    base_config.language_model_type = args.language_model_type
    base_config.vision_model_type = args.vision_model_type
    base_config.calculate_per_token_loss = True

    language_config = deepcopy(base_config)
    language_config = get_language_model_config(language_config)

    if use_te:
        # Padding mask needed for SP/CP.
        padding = args.context_parallel_size > 1 and args.sequence_parallel
        language_transformer_layer_spec = get_layer_spec_te(
            is_vit=False, padding=padding
        )  # TENorm detects LayerNorm/RMS automatically.
    else:
        language_transformer_layer_spec = get_layer_spec(
            is_vit=False, normalization=language_config.normalization
        )

    vision_model_type = args.vision_model_type
    vision_config = deepcopy(base_config)
    vision_config = get_vision_model_config(
        vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
    )
    if vision_model_type.startswith("huggingface"):
        assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1"
        assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0"
        assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
        assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
        assert args.vision_huggingface_model_name_or_path is not None, "Providing --vision-huggingface-model-name-or-path is necessary when using a huggingface vision model"

        vision_config.huggingface_model_name_or_path = args.vision_huggingface_model_name_or_path

        from transformers import AutoConfig
        huggingface_config = AutoConfig.from_pretrained(vision_config.huggingface_model_name_or_path)
        vision_config.hidden_size = huggingface_config.hidden_size

    vision_model_type = args.vision_model_type
    if vision_model_type in ["clip", "siglip", "radio"]:
        if use_te:
            vision_transformer_layer_spec = get_layer_spec_te(
                is_vit=True
            )  # TENorm detects LayerNorm/RMS automatically.
        else:
            vision_transformer_layer_spec = get_layer_spec(
                is_vit=True, normalization=vision_config.normalization
            )
    elif vision_model_type == "internvit":
        from nvlm.internvit import get_internvit_layer_spec
        vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
    elif vision_model_type.startswith("huggingface"):
        vision_transformer_layer_spec = None
    else:
        raise RuntimeError("unsupported vision model type", vision_model_type)

    vision_projection_config = deepcopy(base_config)

    if base_config.language_model_type.startswith("huggingface"):
        assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1"
        assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1"
        assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
        assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
        assert args.language_huggingface_model_name_or_path is not None, "Providing --language-huggingface-model-name-or-path is necessary when using a huggingface language model"
        language_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path
        # Pass to the vision projection config so it can choose the correct ffn hidden size.
        vision_projection_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path

    vision_projection_config = get_vision_projection_config(
        vision_projection_config, language_config.hidden_size
    )

    # --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model.
    if args.encoder_pipeline_model_parallel_size > 0:
        assert (
            args.encoder_pipeline_model_parallel_size == 1
        ), "vision model and projection can only live on 1 pipeline stage."
        if args.encoder_tensor_model_parallel_size > 0:
            vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
            vision_projection_config.tensor_model_parallel_size = (
                args.encoder_tensor_model_parallel_size
            )

    # Make sure the vision model pipeline parallel size is not inherited from the language model pipeline parallel size.
    # 0 is not a valid value for the config, hence max(1, ).
    vision_config.pipeline_model_parallel_size = max(1, args.encoder_pipeline_model_parallel_size)
    vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size

    # Make sure the vision model does not inherit first and last pipeline num layers from the language model.
    vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None

    if vision_projection_config.normalization:
        vision_projection_layer_spec = get_norm_mlp_module_spec_te().submodules
    else:
        vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules

    # Toggle --recompute* for the vision and language model separately.
    if args.recompute_vision:
        if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None:
            vision_config.recompute_num_layers = vision_config.num_layers
    else:
        vision_config.recompute_granularity = None
        vision_config.recompute_method = None
        vision_config.recompute_num_layers = None

    vision_projection_config.recompute_granularity = None
    vision_projection_config.recompute_method = None
    vision_projection_config.recompute_num_layers = None

    # TODO: Vision model and projection do not use SP/CP yet.
    vision_config.sequence_parallel = False
    vision_config.context_parallel_size = 1
    vision_config.tp_comm_overlap = False

    vision_projection_config.sequence_parallel = False
    vision_projection_config.context_parallel_size = 1
    vision_projection_config.tp_comm_overlap = False

    tokenizer = get_tokenizer()
    image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg."
    tile_tags = _get_tile_tags(args, tokenizer)

    model = LLaVAModel(
        language_transformer_config=language_config,
        language_transformer_layer_spec=language_transformer_layer_spec,
        language_vocab_size=args.padded_vocab_size,
        language_max_sequence_length=args.decoder_seq_length,
        vision_transformer_config=vision_config,
        vision_transformer_layer_spec=vision_transformer_layer_spec,
        drop_vision_class_token=args.disable_vision_class_token,
        vision_projection_config=vision_projection_config,
        vision_projection_layer_spec=vision_projection_layer_spec,
        vision_projection_type="mlp",
        allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
        parallel_output=parallel_output,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        language_position_embedding_type=args.position_embedding_type,
        language_rotary_percent=args.rotary_percent,
        pre_process=pre_process,
        post_process=post_process,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        img_h=args.img_h,
        img_w=args.img_w,
        patch_dim=args.patch_dim,
        language_rotary_base=args.rotary_base,
        language_rope_scaling=args.use_rope_scaling,
        image_token_index=image_token_index,
        pixel_shuffle=args.pixel_shuffle,
        tile_tags=tile_tags,
    )

    model.freeze(
        freeze_language_model=args.freeze_LM,
        freeze_vision_model=args.freeze_ViT,
        freeze_vision_projection=False,
    )

    return model


def _get_tile_tags(args, tokenizer):
    """Tile tags are used in NVLM to surround image tiles with text tags."""
    if not args.use_tile_tags:
        return None

    # We expect the tokenized length of the tags to be the same.
    thumbnail_tag_text = "<tile_global_thumbnail>"
    if args.tokenizer_prompt_format == "nvlm-yi-34b":
        thumbnail_tag_text = "<tile_global>"

    assert args.max_num_tiles <= 6, "Up to 6 tile tags used"
    tile_tags_text = [f"<tile_{i}>" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text]

    start_idx = 0
    if tokenizer._prompt_config.has_bos:
        start_idx = 1

    # Convert to tokens [num_tiles, tile_seq_len].
    tile_tags = [tokenizer.tokenize(t)[start_idx:] for t in tile_tags_text]

    return tile_tags
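For illustration, the tag strings that _get_tile_tags builds before tokenization look like this (max_num_tiles = 4 and a non-"nvlm-yi-34b" prompt format are assumed values):

# Standalone illustration of the tile tag strings (before tokenization).
max_num_tiles = 4  # assumed; the code above allows up to 6
tile_tags_text = [f"<tile_{i}>" for i in range(1, max_num_tiles + 1)] + ["<tile_global_thumbnail>"]
print(tile_tags_text)
# ['<tile_1>', '<tile_2>', '<tile_3>', '<tile_4>', '<tile_global_thumbnail>']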
import argparse
import os

import torch
from transformers import AutoModel


def convert(model_name, output_path, tensor_parallel_size, use_te):
    """Convert InternViT HF checkpoint to mcore."""
    hf_model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    hf_state_dict = hf_model.state_dict()
    new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]

    hidden_size = 3200
    num_heads = 25
    dim = 128

    order = torch.ones(3 * hidden_size).long()

    for j in range(num_heads):
        for i in range(dim):
            order[i + dim*3*j] = j*dim+i
            order[dim + i + dim*3*j] = j*dim+i+num_heads*dim
            order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2
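    # Note on the permutation above: the HF fused qkv weight stores all Q rows, then
    # all K rows, then all V rows; indexing with `order` (applied as new_tensor[order]
    # below) rearranges the rows into the per-head interleaved [q_j, k_j, v_j] layout
    # that Megatron's linear_qkv expects.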
    for name, tensor in hf_state_dict.items():
        # Map parameter names to ones used in megatron.
        new_name = ""
        new_tensor = tensor

        # This is used for chunking some tensors to target tensor parallel size.
        chunk_dim = None

        if "embeddings.class_embedding" in name:
            new_name = "class_token"
        elif "embeddings.patch_embedding.weight" in name:
            new_name = "conv1.weight"
        elif "embeddings.patch_embedding.bias" in name:
            new_name = "conv1.bias"
        elif "embeddings.position_embedding" in name:
            new_name = "position_embeddings.weight"
            new_tensor = new_tensor.squeeze(0)
        elif "encoder.layers" in name:
            layer_idx = name.split(".")[2]

            base = f"decoder.layers.{layer_idx}"

            head_dim = 128

            if tensor_parallel_size == 1:
                num_padded_heads = 25
            elif tensor_parallel_size == 8:
                # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism.
                # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model.
                num_padded_heads = 32
            else:
                raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size)

            if "ls1" in name:
                new_name = f"{base}.ls1"
            elif "ls2" in name:
                new_name = f"{base}.ls2"
            elif "attn.qkv.weight" in name:
                new_name = f"{base}.self_attention.linear_qkv.weight"
                num_tensors = 3
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0], :] = new_tensor[order]
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.q_norm.weight" in name:
                new_name = f"{base}.self_attention.q_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.k_norm.weight" in name:
                new_name = f"{base}.self_attention.k_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.proj.weight" in name:
                new_name = f"{base}.self_attention.linear_proj.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:, :new_tensor.shape[-1]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 1
            elif "attn.proj.bias" in name:
                new_name = f"{base}.self_attention.linear_proj.bias"
            elif "mlp.fc1.weight" in name:
                new_name = f"{base}.mlp.linear_fc1.weight"
                chunk_dim = 0
            elif "mlp.fc1.bias" in name:
                new_name = f"{base}.mlp.linear_fc1.bias"
                chunk_dim = 0
            elif "mlp.fc2.weight" in name:
                new_name = f"{base}.mlp.linear_fc2.weight"
                chunk_dim = 1
            elif "mlp.fc2.bias" in name:
                new_name = f"{base}.mlp.linear_fc2.bias"
            elif "norm1" in name:
                new_name = f"{base}.input_layernorm.weight"
            elif "norm2" in name:
                new_name = f"{base}.pre_mlp_layernorm.weight"
            else:
                raise RuntimeError("unexpected transformer layer name", name)
        else:
            raise RuntimeError("unexpected layer name", name)

        assert new_name != "", f"unexpected layer name {name}"

        # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
        extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
        is_extra_state_layer = any([l in new_name for l in extra_state_layers])
        if use_te and is_extra_state_layer:
            layer = new_name.split(".")[-2]
            if layer in extra_state_layers:
                extra_state_name = (
                    new_name[: new_name.rfind(".") + 1] + "_extra_state"
                )  # Replace the weight name.
                for i in range(tensor_parallel_size):
                    new_state_dicts[i]["model"][extra_state_name] = None

        if chunk_dim is None:
            new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
        else:
            new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)

        for i in range(tensor_parallel_size):
            new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()

    for i in range(tensor_parallel_size):
        output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}")
        os.makedirs(output_dir_tp, exist_ok=True)
        output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
        torch.save(new_state_dicts[i], output_path_tp)
        print("saved file", output_path_tp)

    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter")
    parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace")
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.")
    parser.add_argument("--use-te", action="store_true", default=True)
    parser.add_argument("--tensor-parallel-size", type=int, required=True)

    args = parser.parse_args()

    convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te)
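As a usage sketch, the converter above can also be driven directly from Python; the module name internvit_converter and the output directory are assumptions for illustration, not taken from the source:

from internvit_converter import convert  # assumed module name for this file

convert(
    model_name="OpenGVLab/InternViT-6B-448px-V1-5",  # argparse default above
    output_path="./internvit-mcore-tp8",             # hypothetical output directory
    tensor_parallel_size=8,                          # 25 heads are padded to 32 for TP=8
    use_te=True,
)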