Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
File mode changed from 100644 to 100755
......@@ -7,11 +7,11 @@ do
fi
done
mpirun -np 256 --hostfile gptnodes \
mpirun -np 32 --hostfile hostfile_mixtral_8x7B \
--allow-run-as-root \
--bind-to none \
--mca plm_rsh_no_tree_spawn 1 \
train_GPT-MOE_567B.sh node002 --profiling=$profiling > output.log 2>&1
train_mixtral_8x7B_multinodes.sh node066 --profiling=$profiling > output.log 2>&1
wait
......
......@@ -4,18 +4,23 @@ for para in $*
do
if [[ $para == --profiling* ]];then
profiling=${para#*=}
export GPU_FLUSH_ON_EXECUTION=1
export HIP_DIRECT_DISPATCH=0
fi
done
source /opt/dtk/env.sh
# Runs Mixtral 8x7B model
source /opt/dtk/env.sh
# default env
CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH
export GLOG_minloglevel=3
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export GPU_MAX_HW_QUEUES=10
# nccl env
export NCCL_ALGO=Ring
export NCCL_MIN_NCHANNELS=32
export NCCL_MAX_NCHANNELS=32
......@@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
# enable BatchLinear
export GROUPED_GEMM_BatchLinear=1
export GLOG_minloglevel=3
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
......@@ -75,7 +81,7 @@ MOE_ARGS=(
--moe-token-dispatcher-type alltoall
--moe-expert-capacity-factor 0.5
--moe-pad-expert-input-to-capacity
--moe-grouped-gemm
#--moe-grouped-gemm
)
DATA_ARGS=(
......@@ -103,25 +109,17 @@ TRAINING_ARGS=(
TORCH_PROFIE_ARGS=(
--profile
--profile-ranks 0 1 2 3 4 5 6 7 8
--profile-ranks 0 1 2 3 4 5 6 7
--profile-step-start 3
--profile-step-end 4
--profile-dir torch_prof_mixtral_1nodes
--profile-dir torch_prof_mixtral_1nodes_tp2-pp1-ep8-ep_tp1
--use-pytorch-profiler
)
HIP_PROFIE_ARGS=(
--profile
--profile-ranks 0 1 2 3 4 5 6 7 8
--profile-step-start 4
--profile-step-end 5
--use-hip-profiler
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--expert-model-parallel-size 2
--expert-model-parallel-size 8
--expert-tensor-parallel-size 1
--use-distributed-optimizer
--sequence-parallel
......@@ -159,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \
if [[ $profiling == "torch" ]]; then
APP+=" ${TORCH_PROFIE_ARGS[@]}"
elif [[ $profiling == "hip" ]]; then
mkdir -p hip_prof_data
APP+=" ${HIP_PROFIE_ARGS[@]}"
APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
fi
# for Hygon CPU
......
......@@ -4,18 +4,23 @@ for para in $*
do
if [[ $para == --profiling* ]];then
profiling=${para#*=}
export GPU_FLUSH_ON_EXECUTION=1
export HIP_DIRECT_DISPATCH=0
fi
done
source /opt/dtk/env.sh
# Runs Mixtral 8x7B model
source /opt/dtk/env.sh
# default env
CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH
export GLOG_minloglevel=3
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export GPU_MAX_HW_QUEUES=10
# nccl env
export NCCL_ALGO=Ring
export NCCL_MIN_NCHANNELS=32
export NCCL_MAX_NCHANNELS=32
......@@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml"
# enable BatchLinear
export GROUPED_GEMM_BatchLinear=1
export GLOG_minloglevel=3
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
......@@ -99,9 +105,6 @@ TRAINING_ARGS=(
--bf16
--overlap-param-gather
--overlap-grad-reduce
--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1
)
TORCH_PROFIE_ARGS=(
......@@ -109,23 +112,15 @@ TORCH_PROFIE_ARGS=(
--profile-ranks 0 1 2 3 8 9 10 11
--profile-step-start 3
--profile-step-end 4
--profile-dir torch_prof_data_mixtral_2nodes
--profile-dir torch_prof_mixtral_4nodes_tp2-pp8-ep2-ep_tp1
--use-pytorch-profiler
)
HIP_PROFIE_ARGS=(
--profile
--profile-ranks 0 1 2 3 8 9 10 11
--profile-step-start 4
--profile-step-end 5
--use-hip-profiler
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 4
--pipeline-model-parallel-size 4
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 8
--expert-model-parallel-size 2
--expert-tensor-parallel-size 2
--expert-tensor-parallel-size 1
--use-distributed-optimizer
--sequence-parallel
)
......@@ -162,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \
if [[ $profiling == "torch" ]]; then
APP+=" ${TORCH_PROFIE_ARGS[@]}"
elif [[ $profiling == "hip" ]]; then
mkdir -p hip_prof_data
APP+=" ${HIP_PROFIE_ARGS[@]}"
APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
fi
# for Hygon CPU
......
#!/bin/bash
# Runs Mixtral 8x7B model
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
--rotary-base 1000000
)
MOE_ARGS=(
--num-experts 8
--moe-router-topk 2
--moe-router-load-balancing-type aux_loss
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
--moe-token-dispatcher-type alltoall
--overlap-param-gather
--overlap-grad-reduce
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 256
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 4
--expert-model-parallel-size 8
--use-distributed-optimizer
--sequence-parallel
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
......@@ -20,6 +20,32 @@ def get_language_model_config(config):
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "llama3.1_8b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "llama3.1_70B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 28672
elif config.language_model_type == "mistral_7b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
......@@ -74,6 +100,22 @@ def get_language_model_config(config):
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 29568
elif config.language_model_type == "llama3.2_1b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 8192
elif config.language_model_type.startswith("huggingface"):
# Loaded from HuggingFace config file.
pass
else:
raise ValueError(f"unknown language model type {config.language_model_type}")
......@@ -125,9 +167,8 @@ def get_vision_model_config(config, apply_query_key_layer_scaling):
config.layernorm_epsilon = 1e-6
elif config.vision_model_type == "internvit":
config.num_layers = 45
config.num_attention_heads = 32 # Padded for TP=8.
config.num_query_groups = 32 # Padded for TP=8.
config.kv_channels = 128
config.num_attention_heads = ((24 // config.tensor_model_parallel_size) + 1) * config.tensor_model_parallel_size
config.num_query_groups = config.num_attention_heads
config.add_bias_linear = True
config.add_qkv_bias = False
config.hidden_size = 3200
......@@ -144,6 +185,29 @@ def get_vision_model_config(config, apply_query_key_layer_scaling):
config.normalization = 'RMSNorm'
config.layernorm_epsilon = 1e-6
config.apply_rope_fusion = False
elif config.vision_model_type == "radio":
config.num_layers = 32
config.num_attention_heads = 16
config.add_bias_linear = True
config.add_qkv_bias = True
config.hidden_size = 1280
config.ffn_hidden_size = 5120
config.gated_linear_unit = False
config.activation_func = fast_gelu
config.kv_channels = 80
config.num_query_groups = 16
config.layernorm_zero_centered_gamma = False
config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
config.bias_activation_fusion = False
config.bias_dropout_fusion = False
config.attention_softmax_in_fp32 = True
config.normalization = 'LayerNorm'
config.apply_rope_fusion = False
config.qk_layernorm = False
config.layernorm_epsilon = 1e-6
elif config.vision_model_type.startswith("huggingface"):
# Loaded from HuggingFace config file.
pass
else:
raise ValueError(f"unknown vision model type {config.vision_model_type}")
......@@ -158,6 +222,12 @@ def get_vision_projection_config(config, hidden_size):
if config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "llama3.1_8b":
config.ffn_hidden_size = 4096
config.activation_func = torch.nn.functional.gelu
config.layernorm_epsilon = 1e-5
config.add_bias_linear = True
config.normalization = "LayerNorm"
elif config.language_model_type == "mistral_7b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
......@@ -173,6 +243,16 @@ def get_vision_projection_config(config, hidden_size):
config.ffn_hidden_size = 29568
config.normalization = "LayerNorm"
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "llama3.2_1b":
config.ffn_hidden_size = 2048
config.activation_func = torch.nn.functional.gelu
config.normalization = "LayerNorm"
elif config.language_model_type.startswith("huggingface"):
config.activation_func = torch.nn.functional.gelu
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
if "qwen" in hf_config.model_type:
config.ffn_hidden_size = 1536
else:
raise ValueError(f"unknown language model type {config.language_model_type}")
......
......@@ -8,12 +8,13 @@ import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from image_processing import get_visual_transform
from image_processing import find_closest_aspect_ratio, find_closest_area_weighted_aspect_ratio, get_visual_transform
from PIL import Image
from torchvision.transforms import ToPILImage
import numpy as np
import torch
from energon_util import OfflineTargetAspectRatioSample, SampleListSample
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.energon import (
......@@ -177,11 +178,15 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
self.txt_to_token_dict = {}
self.img_h, self.img_w = self.args.img_h, self.args.img_w
self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
# This map is used to reduce the number of tiles used per image if the number of tokens is
# larger than the decoder_seq_length.
self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}
self.find_closest_aspect_ratio_fn = (
find_closest_area_weighted_aspect_ratio if self.args.use_area_weighted_aspect_ratio
else find_closest_aspect_ratio)
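A small illustrative sketch (made-up numbers, not the encoder's exact control flow) of how the degradation map above is walked when a sample would produce too many image tokens: each lookup lowers max_num_tiles until the token budget fits or the map bottoms out at 1.
num_tiles_degradation_map = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1, 1: 1}
max_num_tiles, tokens_per_tile, token_budget = 12, 256, 1024  # hypothetical values
while max_num_tiles * tokens_per_tile > token_budget and max_num_tiles in num_tiles_degradation_map:
    if max_num_tiles == 1:
        break  # cannot degrade any further
    max_num_tiles = num_tiles_degradation_map[max_num_tiles]
print(max_num_tiles)  # 4, since 4 * 256 = 1024 fits the budget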
def _get_total_seq_length(self, input_ids, num_tiles):
"""Calculate expected sequence length given text tokens length and number of tiles."""
total_num_images = len(num_tiles)
......@@ -227,6 +232,10 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
yield self.encode_llava_sft(sample)
elif isinstance(sample, MultiChoiceVQASample):
yield self.encode_any_single_turn_vqa(sample)
# Because the SampleListSample is defined in the Megatron module but loaded by the Energon
# library, we need to resort to the more brittle check:
elif type(sample).__name__ == "SampleListSample":
yield self.encode_sample_list(sample)
else:
raise NotImplementedError("Sample format not supported", sample)
......@@ -236,7 +245,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
imgs = get_visual_transform(
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
self.args.vision_model_type,
self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
)
num_tiles = [len(imgs)]
......@@ -282,7 +291,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
imgs = get_visual_transform(
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
self.args.vision_model_type,
self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
)
num_tiles = [len(imgs)]
......@@ -310,19 +319,44 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
total_len=self._get_total_seq_length(input_ids, num_tiles),
)
def encode_llava_sft(self, sample: SimilarityInterleavedSample):
def encode_sample_list(self, samples: SampleListSample):
"""We encode the list of samples using encode_llava_sft on each sample."""
error_msg = ("You probably don't want to use online packing since SampleListSample is "
"usually used along offline packing.")
assert not self.is_packing_enabled, error_msg
encoded_samples = []
current_length = 0
for sample in samples.samples:
encoded_sample = self.encode_llava_sft(sample, truncate_for_sample_list_packing=True)
if current_length + encoded_sample.total_len > self.packing_seq_length:
break
else:
encoded_samples.append(encoded_sample)
current_length += encoded_sample.total_len
return self.pack_selected_samples(encoded_samples)
def encode_llava_sft(self, sample: Union[SimilarityInterleavedSample, OfflineTargetAspectRatioSample], truncate_for_sample_list_packing=False):
"""Encode SFT sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
# If the target aspect ratio is provided by the dataset, we use it instead of computing
# it with the self.find_closest_aspect_ratio_fn function.
local_find_closest_aspect_ratio_fn = self.find_closest_aspect_ratio_fn
if type(sample).__name__ == "OfflineTargetAspectRatioSample":
target_aspect_ratio = tuple(sample.target_aspect_ratio[0])
assert target_aspect_ratio is not None, "Sample of type OfflineTargetAspectRatioSample needs to define the target aspect ratio."
local_find_closest_aspect_ratio_fn = lambda *args, **kwargs: target_aspect_ratio
has_image = False
if hasattr(sample, "images"):
# We infer whether the sample has an image or not.
if hasattr(sample, "images") and not has_video:
# If this is a text-only sample and we are freezing the LM,
# then use a dummy input image.
if len(sample.images) == 0 and self.args.freeze_LM:
empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
sample.images.append(empty_img)
if len(sample.images) > 0 and not has_video:
if len(sample.images) > 0:
has_image = True
# Note: Some tokenizers may ignore the system prompt.
......@@ -343,10 +377,10 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
image_tag_ids_list.extend(image_tag_ids)
turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
number_image_tags += turn["content"].count(IMAGE_TOKEN)
# For videos, we replace the image tag with the video tag
# For videos, we use the image token to locate where to put the frames.
if has_video:
turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN)
turn["content"] = turn["content"].replace(VIDEO_TOKEN, IMAGE_TOKEN)
number_image_tags += turn["content"].count(IMAGE_TOKEN)
# We re-order the images in sample.images according to how they appear in the conversation.
if len(image_tag_ids_list) > 0:
......@@ -354,10 +388,11 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
# If there is only one image, but several image tags, we assume all the tags refer to the
# same image and duplicate the image:
if len(sample.images) == 1 and number_image_tags > 1:
if not has_video and len(sample.images) == 1 and number_image_tags > 1:
sample.images = sample.images * number_image_tags
number_of_images = len(sample.images)
# We currently only support one video per sample.
number_of_images = 1 if has_video else len(sample.images)
# Fail if there are more image or video tags than images or videos:
error_msg = (
f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
......@@ -368,8 +403,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
if number_image_tags < number_of_images:
for turn in conversation:
if turn["role"] == "user":
tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"]
turn["content"] = IMAGE_TOKEN*(number_of_images-number_image_tags) + "\n" + turn["content"]
break
input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)
......@@ -389,12 +423,13 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
for img in sample.images:
img_tiles = get_visual_transform(
img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
self.args.use_thumbnail, augment, self.args.vision_model_type,
find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn)
imgs += img_tiles
num_tiles += [len(img_tiles)]
if max_num_tiles == 1:
break
if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed:
if sum(num_tiles) * self.num_image_embeddings_per_tile > max_image_token_allowed:
if max_num_tiles in self.num_tiles_degradation_map:
max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
else:
......@@ -408,7 +443,9 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
use_tiling=False
# Grab the selected frames of the video as a tensor with shape
# fchw: (num_frames, num_channels, height, width).
video_fchw = sample.images[0].permute(0, 1, 2, 3)
video_fchw = sample.images.frames
if video_fchw.shape[0] == 0:
raise ValueError(f"Video {sample.__key__} {sample.__restore_key__} {sample.texts} has no frames.")
selected_frames = torch.linspace(
0, video_fchw.shape[0] - 1, self.args.num_frames).long()
video_fchw = video_fchw[selected_frames]
......@@ -418,12 +455,13 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
video_chw = to_pil(video_chw)
imgs += get_visual_transform(
video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
self.args.use_thumbnail, augment, self.args.vision_model_type,
find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn)
num_tiles = [len(imgs)]
else:
imgs = num_tiles = []
if self.is_packing_enabled:
if self.is_packing_enabled or truncate_for_sample_list_packing:
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)
# Some final checks with respect to the number of image tokens and images on the tokenized
......@@ -438,6 +476,9 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
assert np.sum(num_tiles) == len(imgs), error_msg
# We need to ensure that there are at least some trainable tokens in the sample.
assert self.target_has_trainable_tokens(input_ids, num_tiles, target), "Sample has no trainable tokens."
return ImageTaskSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
......@@ -450,6 +491,54 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
total_len=self._get_total_seq_length(input_ids, num_tiles),
)
def target_has_trainable_tokens(self, input_ids, num_tiles, target):
# Compute the loss mask based on extending the image tags with the proper
# number of image tokens, extracting the first self.args.decoder_seq_length tokens, and
# ensuring that some of these tokens have a loss mask > 0.
# Note that this is a bit hacky because we reproduce here parts of the logic that live in
# the model itself. Ideally, the data sampler would return the already processed inputs
# and targets to avoid this duplication.
expanded_target = target.copy()
expanded_target[input_ids==self.img_token_id] = self.img_token_id
expanded_target = self.replace_value_with_repetition(
expanded_target, self.img_token_id,
self.num_image_embeddings_per_tile * np.array(num_tiles), IGNORE_INDEX)
loss_mask = torch.ones(torch.tensor(expanded_target).size(), dtype=torch.float)
loss_mask[expanded_target == self.tokenizer.pad] = 0.0 # mask paddings
loss_mask[expanded_target == IGNORE_INDEX] = 0.0 # mask prompts
loss_mask = torch.cat((loss_mask[1:], torch.zeros((1,))))
loss_mask = loss_mask[:self.args.decoder_seq_length]
return torch.sum(loss_mask) > 0
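A minimal standalone sketch of the masking step above, assuming IGNORE_INDEX is -100 (token values are made up): a single unmasked answer token is enough for a sample to pass the trainable-token check.
import torch

IGNORE_INDEX = -100  # assumed value, matching the llava_model constant imported above
expanded_target = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 42, IGNORE_INDEX])
loss_mask = torch.ones_like(expanded_target, dtype=torch.float)
loss_mask[expanded_target == IGNORE_INDEX] = 0.0  # mask prompt/image positions
loss_mask = torch.cat((loss_mask[1:], torch.zeros(1)))  # shift for next-token prediction
print(torch.sum(loss_mask) > 0)  # tensor(True): the answer token still carries loss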
def replace_value_with_repetition(self, arr, token_to_replace, num_repetition, new_token):
"""
Replace every occurrence of token_to_replace in the input array with the corresponding number of repetitions of new_token.
Args:
arr (Array): Input array to be modified
token_to_replace: token to be replaced
new_token: token inserted in place of each occurrence of token_to_replace
num_repetition (Array): number of repetitions of new_token for each occurrence.
Returns:
Array: New array with token_to_replace replaced by num_repetition repetitions of
new_token
"""
error_msg = "The number of image tokens must match the length of the tile tensor."
assert np.sum(arr==token_to_replace) == len(num_repetition), error_msg
result = []
idx = 0
for item in arr:
if item == token_to_replace:
# If the current item matches token_to_replace, add num_repetition[idx] copies of new_token
result.extend([new_token] * num_repetition[idx])
idx += 1
else:
# Otherwise, keep the original item
result.append(item)
return np.array(result)
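A rough usage sketch with hypothetical token ids (101 stands in for the image token, -100 for IGNORE_INDEX), assuming enc is a TaskEncoder instance: each image token is expanded into as many placeholders as its image has tiles.
import numpy as np

arr = np.array([5, 101, 6, 101, 7])  # two image tokens in the flattened target
out = enc.replace_value_with_repetition(
    arr, token_to_replace=101, num_repetition=np.array([2, 3]), new_token=-100)
print(out)  # [5 -100 -100 6 -100 -100 -100 7]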
def encode_any_single_turn_vqa(self, sample):
"""Encode MultiChoiceVQA or VQA sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
......@@ -467,11 +556,13 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
imgs += get_visual_transform(
video_frame_hwc, self.img_h, self.img_w,
self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
self.args.use_thumbnail, augment, self.args.vision_model_type,
find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn)
else:
imgs = get_visual_transform(
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type,
find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
)
num_tiles = [len(imgs)]
......@@ -545,6 +636,7 @@ class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked,
imgs = get_visual_transform(
sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type,
find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn
)
num_tiles = [len(imgs)]
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import warnings
from dataclasses import dataclass
from typing import Any, List
from megatron.energon import Sample
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@dataclass
class SampleListSample(Sample):
"""Sample type for a list of samples of any type which needs to be packed together.
This is useful for datasets which are packed offline.
"""
#: The list of samples to pack together
samples: List[Any]
class SampleListWebdataset(DefaultDecoderWebdatasetFactory[SampleListSample]):
__sample_type__ = SampleListSample
def __init__(self, path: EPath, **kwargs):
warnings.warn(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content",
DeprecationWarning,
)
super().__init__(path, **kwargs)
@dataclass
class OfflineTargetAspectRatioSample(Sample):
"""Sample type for image + text samples with target aspect ratio computed offline."""
#: The images of the sequence
images: List[torch.Tensor]
#: The texts of the sequence
texts: List[str]
target_aspect_ratio: List[List]
import argparse
import json
from evaluate_vqav2 import compute_vqa_accuracy
from evaluate_mmmu import get_input_output_paths
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="InfoVQA")
results = []
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)
# Make order deterministic.
# results = sorted(results, key=lambda d: d["question_id"])
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
return output_file_path
def infovqa_eval(input_path):
"""Run InfoVQA evaluation."""
result_file_path = merge_input_files(input_path)
return compute_vqa_accuracy(result_file_path, task="InfoVQA")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc = infovqa_eval(args.input_path)
print(f"===== InfoVQA Accuracy {avg_acc:.2f}% =====")
import argparse
import json
from evaluate_vqav2 import compute_vqa_accuracy
from evaluate_mmmu import get_input_output_paths
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="SPDocVQA")
results = []
for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)
# Make order deterministic.
# results = sorted(results, key=lambda d: d["question_id"])
with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
return output_file_path
def spdocvqa_eval(input_path):
"""Run SPDocVQA evaluation."""
result_file_path = merge_input_files(input_path)
return compute_vqa_accuracy(result_file_path, task="SPDocVQA")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()
avg_acc = spdocvqa_eval(args.input_path)
print(f"===== SPDocVQA Accuracy {avg_acc:.2f}% =====")
......@@ -4,6 +4,55 @@ import json
from evaluate_mmmu import get_input_output_paths
from open_flamingo.eval.vqa_metric import VQAEval
# ANLS score calculation based on https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/dist.py#L1
# and https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/score.py#L6
# MIT License. Copyright (c) 2022 Shunsuke KITADA
def levenshtein_distance(s1: str, s2: str) -> int:
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = list(range(len(s1) + 1))
for i2, c2 in enumerate(s2):
dists = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
dists.append(distances[i1])
else:
dists.append(1 + min((distances[i1], distances[i1 + 1], dists[-1])))
distances = dists
return distances[-1]
def normalized_levenshtein_distance(s1: str, s2: str) -> float:
dist = levenshtein_distance(s1, s2)
length = max(len(s1.upper()), len(s2.upper()))
return 0.0 if length == 0 else dist / length
def similarity_function(prediction: str, gold_label: str, threshold: float) -> float:
nl_score = normalized_levenshtein_distance(prediction, gold_label)
return 1 - nl_score if nl_score < threshold else 0.0
def anls_score(
prediction: str, gold_labels: List[str], threshold: float = 0.5
) -> float:
# not case sensitive, but space sensitive
y_pred = " ".join(prediction.strip().lower().split())
anls_scores: List[float] = []
for gold_label in gold_labels:
# not case sensitive, but space sensitive
y_true = " ".join(gold_label.strip().lower().split())
anls_score = similarity_function(y_pred, y_true, threshold)
anls_scores.append(anls_score)
score = max(anls_scores)
return score
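A small worked example with made-up strings, using the default threshold of 0.5: a prediction one edit away from a five-character answer scores 1 - 1/5 = 0.8, while a completely different gold label contributes 0.
# levenshtein("pari", "paris") = 1, max length 5 -> normalized distance 0.2 < 0.5 -> score 0.8
# levenshtein("pari", "lyon") = 4, max length 4 -> normalized distance 1.0 >= 0.5 -> score 0.0
print(anls_score("pari", ["paris", "lyon"]))  # 0.8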
def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
......@@ -80,6 +129,9 @@ def compute_vqa_accuracy(result_file, task):
num_match = sum([pred == ans for ans in gt])
acc = min(1.0, num_match / 3.0)
all_acc.append(acc)
elif task in ("SPDocVQA", "InfoVQA"):
acc = anls_score(prediction=pred, gold_labels=gt, threshold=0.5)
all_acc.append(acc)
elif task == "AI2D":
assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}"
acc = pred == gt[0]
......
......@@ -400,7 +400,7 @@ class MMMUDataset(torch.utils.data.Dataset):
)
class VideoMMMEDataset(torch.utils.data.Dataset):
class VideoMMEDataset(torch.utils.data.Dataset):
"Video MME evaluation dataset."
def __init__(
......@@ -442,7 +442,7 @@ class VideoMMMEDataset(torch.utils.data.Dataset):
self._ground_truth = ground_truth
self._img_h = img_h
self._img_w = img_w
self._use_tiling = use_tiling
self._use_tiling = False
self._max_num_tiles = max_num_tiles
self._use_thumbnail = use_thumbnail
self._num_frames = num_frames
......@@ -463,20 +463,14 @@ class VideoMMMEDataset(torch.utils.data.Dataset):
if self._num_frames == 1:
video_frames = video_frames[None]
imgs = list(
itertools.chain.from_iterable(
get_visual_transform(
img,
self._img_h,
self._img_w,
self._use_tiling,
self._max_num_tiles,
self._use_thumbnail,
augment=False,
vision_model_type=self._vision_model_type,
)
for img in video_frames
)
imgs = []
for img in video_frames:
from torchvision.transforms import ToPILImage
to_pil = ToPILImage()
img = to_pil(img)
imgs += get_visual_transform(
img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles,
self._use_thumbnail, augment=False, vision_model_type=self._vision_model_type
)
for question in gt["questions"]:
......@@ -858,7 +852,7 @@ def get_evaluation_dataset(
vision_model_type=vision_model_type,
)
elif task == "VideoMME":
dataset = VideoMMMEDataset(
dataset = VideoMMEDataset(
input_image_path,
gt_path,
num_samples_per_partition,
......@@ -914,6 +908,40 @@ def get_evaluation_dataset(
no_mask=False,
vision_model_type=vision_model_type,
)
elif task == "SPDocVQA":
keys = {"sample_id": "questionId", "image_id": "image", "question": "question", "answer": "answers"}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
elif task == "InfoVQA":
keys = {"sample_id": "questionId", "image_id": "image_local_name", "question": "question", "answer": "answers"}
dataset = VQADataset(
input_image_path,
gt_path,
num_samples_per_partition,
num_partitions,
partition_id,
keys,
img_h,
img_w,
use_tiling,
max_num_tiles,
use_thumbnail,
vision_model_type,
)
else:
raise NotImplementedError(f"unsupported task {task}")
......
......@@ -16,25 +16,11 @@ pixel_statistics = {
"clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
"internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD),
"radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD),
"huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD),
}
def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, vision_model_type="clip"):
pixel_mean, pixel_std = pixel_statistics[vision_model_type]
assert not augment, "Image augmentation not implemented."
transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type)
if use_tiling:
assert img_h == img_w, "dynamic tiling expects equal tile height and width"
imgs = dynamic_preprocess(img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail)
imgs = [transform(img) for img in imgs]
else:
imgs = [transform(img)]
return imgs
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685
# Copyright (c) 2023 OpenGVLab.
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
......@@ -50,13 +36,52 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
"""
Find the best number of tiles based on the aspect ratio and the area covered by the tiles.
"""
best_factor = float('-inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
factor_based_on_area_n_ratio = (
min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) *
min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio))
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def get_visual_transform(
img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False,
vision_model_type="clip", find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
pixel_mean, pixel_std = pixel_statistics[vision_model_type]
assert not augment, "Image augmentation not implemented."
transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type)
if use_tiling:
assert img_h == img_w, "dynamic tiling expects equal tile height and width"
imgs = dynamic_preprocess(
img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail,
find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn)
imgs = [transform(img) for img in imgs]
else:
imgs = [transform(img)]
return imgs
# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
# Copyright (c) 2023 OpenGVLab.
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
def dynamic_preprocess(
image, min_num=1, max_num=6, image_size=448, use_thumbnail=False,
find_closest_aspect_ratio_fn=find_closest_aspect_ratio):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
......@@ -67,7 +92,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
target_aspect_ratio = find_closest_aspect_ratio_fn(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
......@@ -98,7 +123,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79
# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276
def build_transform(input_size, pixel_mean, pixel_std, vision_model_type):
if vision_model_type in ("siglip", "internvit"):
if vision_model_type in ("siglip", "internvit", "radio", "huggingface"):
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
......
......@@ -89,8 +89,12 @@ def get_layer_spec(is_vit, normalization) -> ModuleSpec:
)
def get_layer_spec_te(is_vit=False) -> ModuleSpec:
def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
# Padding mask is needed for e.g. Context Parallel.
if padding:
assert not is_vit, "padding_causal mask not used with ViT"
attn_mask_type = AttnMaskType.padding_causal
mlp = get_norm_mlp_module_spec_te()
return ModuleSpec(
......
......@@ -30,7 +30,6 @@ def model_provider(
model: A multimodal model.
"""
args = get_args()
assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently."
assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for encoder on its own pipeline rank"
use_te = args.use_te
......@@ -77,21 +76,36 @@ def model_provider(
language_config = get_language_model_config(language_config)
if use_te:
# Padding mask needed for SP/CP.
padding = args.context_parallel_size > 1 and args.sequence_parallel
language_transformer_layer_spec = get_layer_spec_te(
is_vit=False
is_vit=False, padding=padding
) # TENorm detects LayerNorm/RMS automatically.
else:
language_transformer_layer_spec = get_layer_spec(
is_vit=False, normalization=language_config.normalization
)
vision_model_type = args.vision_model_type
vision_config = deepcopy(base_config)
vision_config = get_vision_model_config(
vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
)
if vision_model_type.startswith("huggingface"):
assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1"
assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0"
assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
assert args.vision_huggingface_model_name_or_path is not None, "Providing --vision-huggingface-model-name-or-path is necessary when using huggingface vision model"
vision_config.huggingface_model_name_or_path = args.vision_huggingface_model_name_or_path
from transformers import AutoConfig
huggingface_config = AutoConfig.from_pretrained(vision_config.huggingface_model_name_or_path)
vision_config.hidden_size = huggingface_config.hidden_size
vision_model_type = args.vision_model_type
if vision_model_type in ["clip", "siglip"]:
if vision_model_type in ["clip", "siglip", "radio"]:
if use_te:
vision_transformer_layer_spec = get_layer_spec_te(
is_vit=True
......@@ -103,10 +117,24 @@ def model_provider(
elif vision_model_type == "internvit":
from nvlm.internvit import get_internvit_layer_spec
vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
elif vision_model_type.startswith("huggingface"):
vision_transformer_layer_spec = None
else:
raise RuntimeError("unsupported vision model type", vision_model_type)
vision_projection_config = deepcopy(base_config)
if base_config.language_model_type.startswith("huggingface"):
assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1"
assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1"
assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
assert args.language_huggingface_model_name_or_path is not None, "Providing --language-huggingface-model-name-or-path is necessary when using huggingface language model"
language_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path
# Pass to vision projection config so it can choose the correct ffn hidden size
vision_projection_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path
vision_projection_config = get_vision_projection_config(
vision_projection_config, language_config.hidden_size
)
......@@ -149,9 +177,18 @@ def model_provider(
vision_projection_config.recompute_method = None
vision_projection_config.recompute_num_layers = None
# TODO: Vision model and projection do not use SP/CP yet.
vision_config.sequence_parallel = False
vision_config.context_parallel_size = 1
vision_config.tp_comm_overlap = False
vision_projection_config.sequence_parallel = False
vision_projection_config.context_parallel_size = 1
vision_projection_config.tp_comm_overlap = False
tokenizer = get_tokenizer()
image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg."
tile_tags = _get_tile_tags(args, tokenizer)
......@@ -168,6 +205,7 @@ def model_provider(
vision_projection_type="mlp",
allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
parallel_output=parallel_output,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
language_position_embedding_type=args.position_embedding_type,
language_rotary_percent=args.rotary_percent,
pre_process=pre_process,
......