Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import argparse
import os

import torch


def convert(output_path, tensor_parallel_size, use_te, version):
    device = "cuda"

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=version, progress=True)
    state_dict = model.state_dict()

    new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]

    # Indices from mapping pytorch multihead attention to megatron.
    kv_channels = 80
    hidden_dim = 1280
    num_heads = 16
    indices = []
    for i in range(num_heads):
        lb = i * kv_channels
        ub = (i + 1) * kv_channels
        indices.append(torch.arange(lb, ub, dtype=torch.int))
        indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int))
        indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int))

    indices = torch.cat(indices)

    for name, tensor in state_dict.items():
        # Map parameter names to ones used in megatron.
        new_name = ""
        new_tensor = tensor
        if new_tensor.dtype == torch.float16:
            new_tensor = new_tensor.to(torch.float32)

        # This is used for chunking some tensors to target tensor parallel size.
        chunk_dim = None

        if "summary_idxs" in name:
            continue
        elif "patch_generator" in name:
            if "embedder" in name:
                new_name = "embedder.weight"
                chunk_dim = 0
            elif "cls_token" in name:
                new_name = "class_token"
            elif "pos_embed" in name:
                new_name = "position_embeddings"
        elif "input_conditioner" in name:
            continue
        elif "blocks" in name:
            layer_idx = name.split(".")[2]
            base = f"decoder.layers.{layer_idx}"

            if "attn.qkv.weight" in name:
                new_name = f"{base}.self_attention.linear_qkv.weight"
                new_tensor = new_tensor[indices]
                chunk_dim = 0
            elif "attn.qkv.bias" in name:
                new_name = f"{base}.self_attention.linear_qkv.bias"
                new_tensor = new_tensor[indices]
                chunk_dim = 0
            elif "attn.proj.weight" in name:
                new_name = f"{base}.self_attention.linear_proj.weight"
                chunk_dim = 1
            elif "attn.proj.bias" in name:
                new_name = f"{base}.self_attention.linear_proj.bias"
            elif "norm1.weight" in name:
                new_name = f"{base}.input_layernorm.weight"
                if use_te:
                    new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight"
            elif "norm1.bias" in name:
                new_name = f"{base}.input_layernorm.bias"
                if use_te:
                    new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias"
            elif "mlp.fc1.weight" in name:
                new_name = f"{base}.mlp.linear_fc1.weight"
                chunk_dim = 0
            elif "mlp.fc1.bias" in name:
                new_name = f"{base}.mlp.linear_fc1.bias"
                chunk_dim = 0
            elif "mlp.fc2.weight" in name:
                new_name = f"{base}.mlp.linear_fc2.weight"
                chunk_dim = 1
            elif "mlp.fc2.bias" in name:
                new_name = f"{base}.mlp.linear_fc2.bias"
            elif "norm2.weight" in name:
                new_name = f"{base}.pre_mlp_layernorm.weight"
                if use_te:
                    new_name = f"{base}.mlp.linear_fc1.layer_norm_weight"
            elif "norm2.bias" in name:
                new_name = f"{base}.pre_mlp_layernorm.bias"
                if use_te:
                    new_name = f"{base}.mlp.linear_fc1.layer_norm_bias"

        assert new_name != "", f"unexpected layer name {name}"

        if chunk_dim is None:
            new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
        else:
            new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)

        for i in range(tensor_parallel_size):
            # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
            new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()

            # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
            extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
            is_extra_state_layer = any([l in new_name for l in extra_state_layers])
            if use_te and is_extra_state_layer:
                layer = new_name.split(".")[-2]
                if layer in extra_state_layers:
                    extra_state_name = (
                        new_name[: new_name.rfind(".") + 1] + "_extra_state"
                    )  # Replace the weight name.
                    new_state_dicts[i]["model"][extra_state_name] = None

    for i in range(tensor_parallel_size):
        output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
        os.makedirs(output_dir_tp)
        output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
        torch.save(new_state_dicts[i], output_path_tp)

    with open(os.path.join(output_path, "latest_checkpointed_iteration.txt"), "w") as f:
        f.write("1")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
Convert RADIO weights to megatron format.

Example usage:
python radio_converter.py --output /some/output/folder --tensor-parallel-size 4
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--output", type=str, required=True, help="output directory for megatron state dict file(s)"
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
    )
    parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")
    parser.add_argument("--version", type=str, default="radio_v2.5-h", help="Version of radio to load for conversion")

    args = parser.parse_args()

    convert(args.output, args.tensor_parallel_size, args.use_te, args.version)

    print("done.")
@@ -10,7 +10,9 @@ def add_multimodal_extra_args(parser):
    group.add_argument('--freeze-LM', action='store_true', default=False)
    group.add_argument('--freeze-ViT', action='store_true', default=False)
    group.add_argument('--language-model-type', type=str, required=True)
    group.add_argument('--language-huggingface-model-name-or-path', type=str)
    group.add_argument('--vision-model-type', type=str, default="clip")
    group.add_argument('--vision-huggingface-model-name-or-path', type=str)
    group.add_argument("--disable-vision-class-token", action="store_true", default=False)
    group.add_argument(
        "--allow-missing-vision-projection-checkpoint", action="store_true", default=False
@@ -49,7 +51,7 @@ def add_multimodal_extra_args(parser):
    group.add_argument(
        "--tokenizer-prompt-format",
        type=str,
        choices=["mistral", "llama3", "llama3p1", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"],
        required=True,
        help="Prompt format to use with the tokenizer.",
    )
@@ -74,6 +76,14 @@ def add_multimodal_extra_args(parser):
    group.add_argument(
        "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model"
    )
    group.add_argument(
        "--use-loss-scaling", action="store_true", default=False, help="Scale loss based on conversation turn length (in tokens)."
    )
    group.add_argument(
        "--use-area-weighted-aspect-ratio", action="store_true", default=False,
        help=(
            "When --use-tiling is True, find the aspect ratio to use based on the original "
            "image aspect ratio and the area covered by the tiles.")
    )
    return parser
@@ -11,11 +11,10 @@ Additionally, InternViT introduces some unique features like Layer Scaling.
Those code changes are gathered here.
"""
from functools import partial

import torch

from megatron.core.utils import divide
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TEDotProductAttention,
@@ -92,21 +91,28 @@ class InternViTRMSNorm(MegatronModule):
        return output

    def _gather_var(self, input_, max_dim):
        """Compute statistic across the non-dummy heads."""
        world_size = get_tensor_model_parallel_world_size()

        # Size and dimension.
        last_dim = input_.dim() - 1
        rank = get_tensor_model_parallel_rank()

        num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        valid_ranks = 24 // num_attention_heads_per_partition
        residual_heads = 25 % num_attention_heads_per_partition
        if residual_heads == 0:
            residual_heads = num_attention_heads_per_partition
        max_dim = max_dim * residual_heads

        if rank < valid_ranks:  # Ranks without any dummy attention heads.
            var = input_.sum(-1, keepdim=True)
        elif rank == valid_ranks:  # The only rank which may contain 'residual_heads' dummy attention heads.
            var = input_[..., :max_dim].sum(-1, keepdim=True)
        else:
            var = input_.sum(-1, keepdim=True) * 0.0  # All heads in these ranks are dummy heads: zero out.

        tensor_list = [torch.empty_like(var) for _ in range(world_size)]
        tensor_list[rank] = var
...
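As a sanity check on the generalized logic above, here is a small standalone sketch (not part of the commit). InternViT uses 25 real attention heads which, by assumption here, are padded to 32 with dummy heads so the count divides the tensor-parallel size; the new code derives, for any TP size, how many ranks hold only real heads (valid_ranks) and how many real heads land on the boundary rank (residual_heads):

# Assumption: config.num_attention_heads is the padded head count (e.g. 32).
def head_partition(num_attention_heads_padded, world_size):
    heads_per_partition = num_attention_heads_padded // world_size  # what divide() computes
    valid_ranks = 24 // heads_per_partition       # ranks that hold only real heads
    residual_heads = 25 % heads_per_partition     # real heads on the first partially-dummy rank
    if residual_heads == 0:
        residual_heads = heads_per_partition
    return heads_per_partition, valid_ranks, residual_heads

print(head_partition(32, 8))  # (4, 6, 1): ranks 0-5 are all real, rank 6 holds 1 real head, rank 7 is all dummy
print(head_partition(32, 4))  # (8, 3, 1): ranks 0-2 are all real, rank 3 holds 1 real head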
File mode changed from 100644 to 100755
@@ -101,6 +101,7 @@ OPTIONS=" \
--init-method-std 0.014 \
--attention-dropout ${AD} \
--hidden-dropout ${HD} \
--untie-embeddings-and-output-weights \
--eod-mask-loss \
--bf16 \
--tensorboard-dir=${TENSORBOARD_DIR} \
...
#!/bin/bash
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_APPLY_QK_LAYER_SCALING=0
export TOKENIZERS_PARALLELISM="false"
INPUT_IMAGE_PATH="placeholder"
GROUNDTRUTH_PATH="placeholder"
while [[ $# -gt 0 ]]; do
    case $1 in
        --input-image-path)
            INPUT_IMAGE_PATH="$2"
            shift
            shift
            ;;
        --input-metadata-path)
            INPUT_METADATA_PATH="$2"
            shift
            shift
            ;;
        --num-frames)
            NUM_FRAMES="$2"
            shift
            shift
            ;;
        -g|--groundtruth-path|--gt-path)
            GROUNDTRUTH_PATH="$2"
            shift
            shift
            ;;
        -o|--output-path)
            OUTPUT_PATH="$2"
            shift
            shift
            ;;
        -m|--model-path)
            MODEL_PATH="$2"
            shift
            shift
            ;;
        --task)
            TASK="$2"
            shift
            shift
            ;;
        -*|--*)
            echo "Invalid option $1"
            exit 1
            ;;
    esac
done
# Please modify these as needed.
NUM_PARTITIONS=0
START=0
END=0
SEQ_LEN=256
DECODER_SEQ_LEN=16384
EXTRA_ARGS=" --pixel-shuffle"
for PARTITION_ID in $( eval echo {$START..$END} )
do
torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \
--attention-softmax-in-fp32 \
--transformer-impl transformer_engine \
--use-te \
--use-checkpoint-args \
--normalization RMSNorm \
--norm-epsilon 1e-06 \
--language-model-type=qwen2.5_7B \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--position-embedding-type rope \
--rotary-percent 1.0 \
--rotary-base 1000000 \
--swiglu \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 1 \
--group-query-attention \
--num-query-groups 4 \
--num-layers 28 \
--hidden-size 3584 \
--ffn-hidden-size 18944 \
--add-qkv-bias \
--num-attention-heads 28 \
--max-position-embeddings 32768 \
--no-masked-softmax-fusion \
--load ${MODEL_PATH} \
--tokenizer-type MultimodalTokenizer \
--tokenizer-model Qwen/Qwen2.5-7B-Instruct \
--tokenizer-prompt-format qwen2p5 \
--bf16 \
--micro-batch-size 1 \
--seq-length ${SEQ_LEN} \
--decoder-seq-length ${DECODER_SEQ_LEN} \
--out-seq-length 128 \
--temperature 1.0 \
--img-h 448 \
--img-w 448 \
--patch-dim 14 \
--seed 153 \
--top_k 1 \
--no-load-rng \
--no-load-optim \
--input-image-path ${INPUT_IMAGE_PATH} \
--num-partitions ${NUM_PARTITIONS} \
--partition-id ${PARTITION_ID} \
--output-path ${OUTPUT_PATH} \
--gt-path ${GROUNDTRUTH_PATH} \
--task ${TASK} \
${EXTRA_ARGS} \
--special-tokens "<image>" "<img>" "</img>" \
--vision-model-type internvit \
--num-frames ${NUM_FRAMES} \
--ckpt-format torch
done
@@ -107,6 +107,7 @@ OPTIONS=" \
--init-method-std 0.014 \
--attention-dropout ${AD} \
--hidden-dropout ${HD} \
--untie-embeddings-and-output-weights \
--eod-mask-loss \
--bf16 \
--tensorboard-dir=${TENSORBOARD_DIR} \
...
File mode changed from 100644 to 100755
#!/bin/bash
# Your SBATCH commands here if using SLURM.
# Please launch this script from megatron-lm root.
# Train a multimodal model.
export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_ALGO=^NVLS
export TOKENIZERS_PARALLELISM=false
USER=$SLURM_JOB_USER
# Auto-detect batch or interactive mode.
which srun
BATCH=$((1-$?))
DEBUG=0
if [[ $BATCH -eq 0 ]]; then
    DATETIME=`date +'%y-%m-%d-%H-%M-%S'`
    MODEL_NAME="qwen2.5-7B-internvit-video-sft-nvlm-${DATETIME}"
else
    MODEL_NAME="qwen2.5-7B-internvit-video-sft-nvlm"
    DEBUG=0
fi
WORKSPACE="<some dir>"
SOURCE=`pwd`
OUTPUT_BASE="${WORKSPACE}/output"
OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}"
FINETUNE_DIR="${OUTPUT}/checkpoints"
LOGS_DIR="${OUTPUT}/logs"
TENSORBOARD_DIR="${OUTPUT}/tensorboard"
# From pretraining. The pretraining checkpoint should have tensor parallel size 4.
LOAD_NAME="mcore-qwen2p5-7b-internvit-tp4"
CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints"
DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml"
if [[ $DEBUG -eq 1 ]]; then
    MBZ=1
    BZ=1
    NW=0
    AD=0.0
    HD=0.0
    LI=1
    # This is just for interactive testing purposes. Do not use for proper training.
    EXTRA_ARGS="--freeze-LM"
    ALLOW_NONDETERMINISTIC=1
else
    MBZ=1
    BZ=256
    NW=8
    AD=0.0
    HD=0.0
    LI=5
    EXTRA_ARGS=""
    ALLOW_NONDETERMINISTIC=1
fi
USE_TILING=1
SEQ_LEN=1024
DECODER_SEQ_LEN=16384
MAX_POS_EMBED=32768
TRAIN_SAMPLES=6602173
WARMUP_SAMPLES=198065
if [[ $BATCH -eq 0 ]]; then
    # Runs out of GPU memory in interactive mode without this.
    EXTRA_ARGS+=" --freeze-LM"
fi
if [[ $USE_TILING -eq 1 ]]; then
    EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail"
    SEQ_LEN=256
fi
OPTIONS=" \
--swiglu \
--use-distributed-optimizer \
--num-workers ${NW} \
--num-layers 28 \
--hidden-size 3584 \
--norm-epsilon 1e-06 \
--normalization RMSNorm \
--num-attention-heads 28 \
--exit-duration-in-mins 110 \
--group-query-attention \
--num-query-groups 4 \
--ffn-hidden-size 18944 \
--add-qkv-bias \
--seq-length ${SEQ_LEN} \
--decoder-seq-length ${DECODER_SEQ_LEN} \
--max-position-embeddings ${MAX_POS_EMBED} \
--dataloader-seq-length ${DECODER_SEQ_LEN} \
--tokenizer-type MultimodalTokenizer \
--tokenizer-model Qwen/Qwen2.5-7B-Instruct \
--tokenizer-prompt-format qwen2p5 \
--pixel-shuffle \
--position-embedding-type rope \
--rotary-percent 1.0 \
--rotary-base 1000000 \
--disable-bias-linear \
--pipeline-model-parallel-size 1 \
--tensor-model-parallel-size 4 \
--language-model-type qwen2.5_7B \
--vision-model-type internvit \
--micro-batch-size ${MBZ} \
--global-batch-size ${BZ} \
--lr 2e-6 \
--min-lr 2.5e-7 \
--train-samples ${TRAIN_SAMPLES} \
--lr-warmup-samples ${WARMUP_SAMPLES} \
--lr-decay-style cosine \
--clip-grad 10 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.014 \
--attention-dropout ${AD} \
--hidden-dropout ${HD} \
--eod-mask-loss \
--bf16 \
--tensorboard-dir ${TENSORBOARD_DIR} \
--img-h 448 \
--img-w 448 \
--patch-dim 14 \
--data-path ${DATA_TRAIN} \
--dataloader-type external \
--split 100,0,0 \
--prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \
--log-interval ${LI} \
--save-interval 500 \
--eval-interval 500 \
--eval-iters 10 \
--log-params-norm \
--log-num-zeros-in-grad \
${EXTRA_ARGS} \
--save ${FINETUNE_DIR} \
--load ${FINETUNE_DIR} \
--pretrained-checkpoint ${CHECKPOINT_DIR} \
--distributed-timeout-minutes 60 \
--allow-missing-vision-projection-checkpoint \
--dataloader-save ${FINETUNE_DIR}/dataloader \
--disable-vision-class-token \
--use-te \
--ckpt-format torch \
--num-frames 32 \
--use-checkpoint-args \
--image-tag-type internvl \
--recompute-granularity full \
--recompute-method block \
--recompute-num-layers 28 \
--recompute-vision \
"
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC}
export NVTE_APPLY_QK_LAYER_SCALING=0
# Interactive or batch mode
if [[ $BATCH -eq 0 ]]; then
    torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS}
else
    run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}"

    DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`

    srun -l --verbose \
        --container-image <path to docker image> \
        --container-mounts "<some mount>" \
        --output=${LOGS_DIR}/%x_%j_$DATETIME.log \
        sh -c "${run_cmd}"

    set +x
fi
@@ -25,6 +25,18 @@ from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.inference.text_generation.api import generate_and_post_process
from megatron.inference.text_generation.forward_step import ForwardStep
from megatron.inference.text_generation.communication import broadcast_int_list
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.vlm_text_generation_controller import (
    VLMTextGenerationController,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import (
    VLMInferenceWrapper,
)
from megatron.training import get_args, get_model, get_tokenizer, print_rank_0
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
@@ -60,6 +72,8 @@ def add_text_generation_args(parser):
            "OCRBench",
            "MathVista",
            "AI2D",
            "InfoVQA",
            "SPDocVQA",
        ],
        help="Generation task to run",
    )
@@ -68,6 +82,8 @@ def add_text_generation_args(parser):
    )
    group.add_argument("--config-path", type=str, help="Evaluation config file to use.")
    group.add_argument("--use-mcore-inference", action="store_true", default=False, help="Use the MCore inference API")

    # Add common multimodal arguments needed for e.g. building the model.
    parser = add_multimodal_extra_args(parser)
@@ -153,15 +169,61 @@ def generate_samples(model, config: EvaluationConfig, print_output):
        args.use_tile_tags,
    )

    if args.use_mcore_inference:
        inference_wrapper_config = InferenceWrapperConfig(
            hidden_size=args.hidden_size,
            inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
            fp32_residual_connection=args.fp32_residual_connection,
            params_dtype=args.params_dtype,
            padded_vocab_size=args.padded_vocab_size,
        )

        inference_wrapped_model = VLMInferenceWrapper(model, inference_wrapper_config)

        tokenizer = get_tokenizer()

        controller = VLMTextGenerationController(
            inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
        )

        inference_engine = MCoreEngine(
            controller, max_batch_size=1, random_seed=args.seed
        )

        sampling_params = SamplingParams(
            temperature=config.temperature,
            top_k=config.top_k,
            top_p=config.top_p,
            num_tokens_to_generate=config.out_seq_length,
        )

    for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader):
        imgs = imgs.to("cuda")
        num_tiles = num_tiles.to("cuda")

        conv = get_conversation(config.task, question)

        if not args.use_mcore_inference:
            forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length)

        if is_first_rank():
            if args.use_mcore_inference:
                inference_request = VLMInferenceRequest(
                    request_id=inference_engine.get_new_request_id(),
                    prompt=conv,
                    prompt_tokens=controller.tokenize_prompt(conv),
                    inference_parameters=sampling_params,
                    num_img_embeddings_per_tile=num_img_embeddings_per_tile,
                    imgs=imgs,
                    num_tiles=num_tiles,
                    decoder_seq_length=args.decoder_seq_length,
                )
                results: List[InferenceRequest] = inference_engine.generate(
                    inference_requests=[inference_request]
                )

                resp_sentences = [
                    tokenizer.detokenize(result.prompt_tokens) + result.generated_text
                    for result in results
                ]
            else:
                resp_sentences, _, _, _ = generate_and_post_process(
                    model,
                    forward_step=forward_step,
@@ -192,6 +254,8 @@ def generate_samples(model, config: EvaluationConfig, print_output):
                "OCRBench",
                "MathVista",
                "AI2D",
                "InfoVQA",
                "SPDocVQA",
            ):
                output_name = "answer"
            elif config.task in ("MMMU"):
@@ -220,6 +284,8 @@ def generate_samples(model, config: EvaluationConfig, print_output):
                "OCRBench",
                "MathVista",
                "AI2D",
                "InfoVQA",
                "SPDocVQA",
            ):
                if isinstance(answers, str):
                    answers = [answers]
@@ -238,6 +304,21 @@ def generate_samples(model, config: EvaluationConfig, print_output):
                yield output
                idx += 1
        else:
            if args.use_mcore_inference:
                inference_request = VLMInferenceRequest(
                    request_id=inference_engine.get_new_request_id(),
                    prompt=conv,
                    prompt_tokens=controller.tokenize_prompt(conv),
                    inference_parameters=sampling_params,
                    num_img_embeddings_per_tile=num_img_embeddings_per_tile,
                    imgs=imgs,
                    num_tiles=num_tiles,
                    decoder_seq_length=args.decoder_seq_length,
                )
                inference_engine.generate(
                    inference_requests=[inference_request]
                )
            else:
                generate_and_post_process(
                    model, forward_step=forward_step, detokenize_segments=False, data_parallel=True
@@ -310,7 +391,6 @@ def generate_and_write_samples(model, config, print_output=True):
    if is_first_rank():
        output_file.close()


class VLMForwardStep(ForwardStep):
    """Inference forward step for a multimodal model."""
@@ -411,7 +491,7 @@ def get_conversation(task, question):
                "content": f"{IMAGE_TOKEN}\nProvide a one-sentence caption for provided image.",
            },
        ]
    elif task in ("TextVQA", "VQAv2", "ChartQA", "InfoVQA", "SPDocVQA"):
        conversation = [
            {"role": "system", "content": "Answer the questions."},
            {
@@ -443,7 +523,7 @@ def get_conversation(task, question):
        conversation = [
            {"role": "system", "content": "Answer the questions."},
            {"role": "user", "content": f"{IMAGE_TOKEN}\n{q}"},
        ]

    return conversation
@@ -451,7 +531,7 @@ def get_conversation(task, question):
def get_prompt_and_generated(prompt_and_generation, prompt_format):
    """Strip prompt and other unnecessary text from generation."""
    if prompt_format in ("llama3", "llama3p1"):
        splitted = prompt_and_generation.split("<|start_header_id|>assistant<|end_header_id|>\n\n")
        prompt = splitted[0]
        generated = splitted[1]
...
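Condensed from the hunks above, the new --use-mcore-inference path amounts to the following sequence (a sketch assembled from the added code, not a drop-in replacement; all argument values come from args/config at runtime):

# Built once, before the dataloader loop.
inference_wrapper_config = InferenceWrapperConfig(
    hidden_size=args.hidden_size,
    inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
    fp32_residual_connection=args.fp32_residual_connection,
    params_dtype=args.params_dtype,
    padded_vocab_size=args.padded_vocab_size,
)
tokenizer = get_tokenizer()
controller = VLMTextGenerationController(
    inference_wrapped_model=VLMInferenceWrapper(model, inference_wrapper_config),
    tokenizer=tokenizer,
)
inference_engine = MCoreEngine(controller, max_batch_size=1, random_seed=args.seed)
sampling_params = SamplingParams(
    temperature=config.temperature,
    top_k=config.top_k,
    top_p=config.top_p,
    num_tokens_to_generate=config.out_seq_length,
)

# Per sample: wrap the prompt plus image tensors in a VLMInferenceRequest and generate.
request = VLMInferenceRequest(
    request_id=inference_engine.get_new_request_id(),
    prompt=conv,
    prompt_tokens=controller.tokenize_prompt(conv),
    inference_parameters=sampling_params,
    num_img_embeddings_per_tile=num_img_embeddings_per_tile,
    imgs=imgs,
    num_tiles=num_tiles,
    decoder_seq_length=args.decoder_seq_length,
)
results = inference_engine.generate(inference_requests=[request])
resp_sentences = [tokenizer.detokenize(r.prompt_tokens) + r.generated_text for r in results]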
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Pretrain or SFT multimodal."""
import math
import os
import sys
from functools import partial
@@ -17,6 +18,7 @@ from multimodal_args import add_multimodal_extra_args
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.multimodal import context_parallel
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, LLaVAModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
@@ -25,10 +27,10 @@ from megatron.core.parallel_state import (
    is_pipeline_last_stage,
)
from megatron.training import get_args, get_timers, get_tokenizer, pretrain
from megatron.training.utils import is_last_rank, get_batch_on_this_cp_rank


def get_batch(data_iterator, image_token_index, img_seq_len):
    """Generate a batch

    Note: attn_mask_type in layer_specs.py sets the attention mask. Attention mask is None here.
@@ -66,9 +68,17 @@ def get_batch(data_iterator):
    cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"]
    max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"]

    # No image input (text-only sample) if the dataloader returned a size 1 image.
    if imgs.shape == torch.Size([1, 1]):
        # FSDP can hang with text-only samples. A workaround is to run a valid dummy image through the vision
        # model and then add image embeddings with a zero multiplier.
        if args.use_torch_fsdp2:
            imgs = torch.zeros((1, 3, args.img_h, args.img_w), dtype=torch.float32, device=data_text.device)
            num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device)
        else:
            # A similar workaround is not needed without FSDP, and we can use an empty image.
            # FIXME: text-only data can still cause a hang in the special case where
            # the vision model is on its own pipeline rank and --freeze-ViT is enabled.
            imgs = torch.tensor([], dtype=torch.float32, device=data_text.device)
            num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device)
@@ -109,6 +119,24 @@ def get_batch(data_iterator):
    loss_mask, position_ids = get_ltor_masks_and_position_ids(tokens, labels, tokenizer.pad)
    torch.cuda.nvtx.range_pop()

    # If context parallel is enabled, must shard inputs to CP ranks.
    if args.context_parallel_size > 1 or args.sequence_parallel:
        assert tokens.shape[0] == 1, "micro-batch-size > 1 not supported yet with CP"

        num_image_tokens = torch.sum(tokens == image_token_index).item()
        num_image_embeddings = num_image_tokens * img_seq_len - num_image_tokens
        seq_len = text_length + num_image_embeddings

        # CP expects the sequence length to be divisible by the CP size, so apply padding.
        mp_padding_needed = context_parallel.get_padding(
            seq_len, args.context_parallel_size,
            args.tensor_model_parallel_size, args.sequence_parallel,
        )
        tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed)) for item in (tokens, position_ids, labels, loss_mask)]

        # Get PackedSeqParams that indicate the amount of padding for TransformerEngine.
        packed_seq_params = context_parallel.get_packed_seq_params(tokens, num_image_embeddings, mp_padding_needed, args.context_parallel_size, True)

    return (
        tokens,
        labels,
@@ -137,7 +165,82 @@ def get_ltor_masks_and_position_ids(input_ids, target, pad_token):
    return loss_mask, position_ids


def get_mask_start_and_end_idx(arr):
    """
    Return a list of (start, end) index tuples for the contiguous non-zero sub-arrays of arr.

    For instance, if arr = [0, 1, 0, 0, 1, 1] then
    get_mask_start_and_end_idx(arr) = [(1, 1), (4, 5)]
    such that arr[1:1+1] = [1] and arr[4:5+1] = [1, 1].
    """
    mask = (arr != 0)

    mask_int = mask.int()
    diff = mask_int[1:] - mask_int[:-1]
    start_indices = (diff == 1).nonzero(as_tuple=False).flatten() + 1
    end_indices = (diff == -1).nonzero(as_tuple=False).flatten()

    if len(mask) == 0:
        return []
    if mask[0]:
        start_indices = torch.cat((torch.tensor([0], device=arr.device), start_indices))
    if mask[-1]:
        end_indices = torch.cat((end_indices, torch.tensor([len(arr) - 1], device=arr.device)))

    sequences = list(zip(start_indices.tolist(), end_indices.tolist()))
    return sequences


def scaled_loss_func(loss_mask, output_tensor):
    """
    Scaled loss function.

    Scale the loss for each conversation turn using the formula:
        1 / sum_j[ sqrt(length(loss_turn_j)) ] * sum_i[ sum(loss_turn_i) / sqrt(length(loss_turn_i)) ]
    where the loss mask is used to infer the start / end of the conversation turns.
    """
    losses = output_tensor.float()

    loss_list = []
    num_valid_labels_list = []
    for idx in range(losses.shape[0]):
        loss_this_sample = losses[idx]
        turn_start_end_list = get_mask_start_and_end_idx(loss_mask[idx])
        for turn_start, turn_end in turn_start_end_list:
            # Compute the loss for each turn.
            loss_this_turn = loss_this_sample[turn_start:turn_end + 1].sum()
            assert (1 - loss_mask)[idx][turn_start:turn_end + 1].sum() < 1.0
            num_valid_labels_this_turn = turn_end - turn_start + 1
            loss_this_turn = loss_this_turn / num_valid_labels_this_turn
            loss_list.append(loss_this_turn)
            # Record the number of valid labels for each turn.
            num_valid_labels_list.append(num_valid_labels_this_turn)

    base_num = sum([math.sqrt(each) for each in num_valid_labels_list])
    for idx in range(len(loss_list)):
        # Normalize the loss for each turn.
        loss_list[idx] = loss_list[idx] * math.sqrt(num_valid_labels_list[idx]) / base_num

    total_loss = torch.stack(loss_list).sum()
    total_tokens = torch.ones_like(total_loss)

    loss = torch.cat([total_loss.view(1), total_tokens.view(1)])

    reporting_loss = loss.clone().detach()
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    local_num_tokens = loss[1].clone().detach().to(torch.int)

    return (
        total_loss,
        local_num_tokens,
        {'lm loss': (reporting_loss[0], reporting_loss[1])},
    )
def loss_func(loss_mask, output_tensor):
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.contiguous().view(-1).float()
@@ -146,12 +249,20 @@ def loss_func(loss_mask, output_tensor):
    total_loss = torch.sum(losses.view(-1) * loss_mask)
    loss = torch.cat([total_loss.view(1), total_tokens.view(1)])

    if args.context_parallel_size > 1:
        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

    reporting_loss = loss.clone().detach()
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    local_num_tokens = loss[1].clone().detach().to(torch.int)

    # We multiply by context parallel size because later there will be a divide by CP(+DP) size.
    return (
        loss[0] * args.context_parallel_size,
        local_num_tokens,
        {'lm loss': (reporting_loss[0], reporting_loss[1])}
    )
def forward_step(data_iterator, model: LLaVAModel):
@@ -178,7 +289,7 @@ def forward_step(data_iterator, model: LLaVAModel):
        images,
        num_image_tiles,
        packed_seq_params,
    ) = get_batch(data_iterator, model.module.module.image_token_index, model.module.module.img_seq_len)
    timers('batch-generator').stop()

    output_tensor, loss_mask = model(
@@ -191,8 +302,13 @@ def forward_step(data_iterator, model: LLaVAModel):
        num_image_tiles=num_image_tiles,
        packed_seq_params=packed_seq_params,
    )

    args = get_args()
    if args.use_loss_scaling:
        loss_function = partial(scaled_loss_func, loss_mask)
    else:
        loss_function = partial(loss_func, loss_mask)

    return output_tensor, loss_function


def llava_embedding_ranks(pp_ranks):
...
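To make the turn-level scaling above concrete, here is a small standalone check (not part of the commit; the mask and losses are made-up toy values) of how get_mask_start_and_end_idx splits a loss mask into turns and how the sqrt-length weights from the docstring combine:

import math
import torch

# Toy loss mask with two turns: one of length 2, one of length 4.
loss_mask = torch.tensor([0, 1, 1, 0, 1, 1, 1, 1])
per_token_loss = torch.ones(8)  # pretend every token has loss 1.0

# get_mask_start_and_end_idx(loss_mask) -> [(1, 2), (4, 7)]
turns = [(1, 2), (4, 7)]
lengths = [end - start + 1 for start, end in turns]  # [2, 4]
turn_means = [float(per_token_loss[s:e + 1].sum()) / l for (s, e), l in zip(turns, lengths)]  # [1.0, 1.0]

base = sum(math.sqrt(l) for l in lengths)  # sqrt(2) + sqrt(4)
total = sum(mean * math.sqrt(l) / base for mean, l in zip(turn_means, lengths))
print(total)  # 1.0, because every turn's mean loss is 1.0
# Each turn contributes its mean loss with weight sqrt(turn length) / sum_j sqrt(turn length_j),
# so a long turn no longer outweighs a short one in direct proportion to its token count.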
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755