diff --git a/CHANGELOG.md b/CHANGELOG.md index 7960574199984cfc24f9f1c1b853338f0fbd10f6..1b59cd874a6b7f63c11b2479cdef98713f4cd3da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,122 +1,132 @@ -# Changelog - -## NVIDIA Megatron Core 0.9.0 - -- Uneven pipeline parallelism - - Enable pipeline parallelism where first and last ranks have fewer transformer layers than the intermediate ranks -- Per layer CUDAGraph support for GPT training with Transformer Engine modules -- Enable different TP sizes for the vision encoder -- Enable pipeline parallelism for T5 & Llava models -- Support multi-tile multi-image input in Llava models -- MoE - - FP8 support - - Runtime upcycling support - - Dispatcher implementation optimizations - - Shared expert support with overlapping optimizations - - Qwen Model support -- Known Issues - - When using sequence parallel, during the transformer block forward pass, dropout is not using the appropriate rng context. - - -## NVIDIA Megatron Core 0.8.0 - -- Multimodal - - Added initial support for training vision language models using the LLaVA architecture - - Added initial support for inference with multimodal inputs - - End-to-end multimodal example from data collection to training to evaluation is provided in examples/multimodal -- MoE - - Context Parallel support. - - Distributed checkpoint support for grouped GEMM. -- Mamba - -## NVIDIA Megatron Core 0.7.0 - -- MoE - - Token drop support - - Several efficiency optimizations - - Improved model parallelism - - Memory optimizations -- Distributed checkpointing - - Enabled for Retro - - Asynchronous checkpoint saving -- Several minor bug fixes, speed improvements, and memory optimizations - -## NVIDIA Megatron Core 0.6.0 - -- MoE (Mixture of Experts) - - Performance optimization - - Communication optimization for multi GPU and Single GPU - - 23% improvement (323 TFLOPS/GPU) over MCore 0.5.0 on Mixtral with Hopper BF16 - - GroupedMLP enhancement for Hopper - - DP Overlapping. Support overlapping computation with gradient reduction and parameter gathering. - - All-to-All based Token Dispatcher - - Layer-wise logging for load balancing loss. - - Improved expert parallel support including distributed optimizer. -- Distributed optimizer -- RETRO - - Data processing -- BERT - - Distributed checkpointing -- Dist checkpointing - - PyTorch native distributed backend - - Improved saving/loading speed -- TensorRT-LLM Export - - Integration with TensorRT Model Optimizer Post-training quantization (PTQ) - - Text generation driver to perform PTQ in Megatron-LM - - Llama2 and Nemotron3-8b examples to use TensorRT-LLM unified build API to build engine after training. -- Several minor enhancements, bug fixes, and documentation updates - -## NVIDIA Megatron Core 0.5.0 - -### Key Features and Enhancements - -Megatron core documentation is now [live!](https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start) - -### Model Features - -- MoE (Mixture of Experts) - - Support for Z-loss, Load balancing and Sinkhorn - - Layer and communications refactor - - Richer parallelism mappings and EP can be combined with other model parallel techniques for larger MoE variants, e.g. 
EP + TP + DP + SP + PP - - Token dropless architecture with Top-K routing - - Performance optimization with with GroupedGEMM when number of local experts is > 1 - - Distributed checkpointing -- Interleaved rotary embedding - -### Datasets - -- Masked WordPiece datasets for BERT and T5 -- Raw and mock datasets - -### Parallelism - -### Performance - -- Activation offloading to CPU -- Rope and Swiglu fusion -- Sliding window attention (via Transformer Engine) - -### General Improvements - -- Timers - -## NVIDIA Megatron Core 0.4.0 - -### Key Features and Enhancements - -#### Models - -- BERT -- RETRO -- T5 - -#### Parallelism - -- Mixture of Experts support for GPT -- Model parallel efficient Distributed Data Parallel (DDP) -- Context Parallel (2D Tensor Parallel) support - -#### Datasets - -- GPT Dataset -- Blended Dataset +# Changelog + +## NVIDIA Megatron Core 0.10.0 + +- Adding MLA to MCore +- Enable FP8 for GroupedMLP +- MoE Parallel Folding +- Enhance MoE Architecture: Support MoE Layer Frequency Patterns and Configurable MoE FFN Hidden Size +- Multimodal: NVLM training and evaluation support in MCore +- Mamba Hybrid + - Increase performance and reduce memory footprint of Triton language/compiler distributed caching + - Add more unit testing and fix bugs + +## NVIDIA Megatron Core 0.9.0 + +- Uneven pipeline parallelism + - Enable pipeline parallelism where first and last ranks have fewer transformer layers than the intermediate ranks +- Per layer CUDAGraph support for GPT training with Transformer Engine modules +- Enable different TP sizes for the vision encoder +- Enable pipeline parallelism for T5 & Llava models +- Support multi-tile multi-image input in Llava models +- MoE + - FP8 support + - Runtime upcycling support + - Dispatcher implementation optimizations + - Shared expert support with overlapping optimizations + - Qwen Model support +- Known Issues + - When using sequence parallel, during the transformer block forward pass, dropout is not using the appropriate rng context. + +## NVIDIA Megatron Core 0.8.0 + +- Multimodal + - Added initial support for training vision language models using the LLaVA architecture + - Added initial support for inference with multimodal inputs + - End-to-end multimodal example from data collection to training to evaluation is provided in examples/multimodal +- MoE + - Context Parallel support. + - Distributed checkpoint support for grouped GEMM. +- Mamba + +## NVIDIA Megatron Core 0.7.0 + +- MoE + - Token drop support + - Several efficiency optimizations + - Improved model parallelism + - Memory optimizations +- Distributed checkpointing + - Enabled for Retro + - Asynchronous checkpoint saving +- Several minor bug fixes, speed improvements, and memory optimizations + +## NVIDIA Megatron Core 0.6.0 + +- MoE (Mixture of Experts) + - Performance optimization + - Communication optimization for multi GPU and Single GPU + - 23% improvement (323 TFLOPS/GPU) over MCore 0.5.0 on Mixtral with Hopper BF16 + - GroupedMLP enhancement for Hopper + - DP Overlapping. Support overlapping computation with gradient reduction and parameter gathering. + - All-to-All based Token Dispatcher + - Layer-wise logging for load balancing loss. + - Improved expert parallel support including distributed optimizer. 
+- Distributed optimizer +- RETRO + - Data processing +- BERT + - Distributed checkpointing +- Dist checkpointing + - PyTorch native distributed backend + - Improved saving/loading speed +- TensorRT-LLM Export + - Integration with TensorRT Model Optimizer Post-training quantization (PTQ) + - Text generation driver to perform PTQ in Megatron-LM + - Llama2 and Nemotron3-8b examples to use TensorRT-LLM unified build API to build engine after training. +- Several minor enhancements, bug fixes, and documentation updates + +## NVIDIA Megatron Core 0.5.0 + +### Key Features and Enhancements + +Megatron core documentation is now [live!](https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start) + +### Model Features + +- MoE (Mixture of Experts) + - Support for Z-loss, Load balancing and Sinkhorn + - Layer and communications refactor + - Richer parallelism mappings and EP can be combined with other model parallel techniques for larger MoE variants, e.g. EP + TP + DP + SP + PP + - Token dropless architecture with Top-K routing + - Performance optimization with with GroupedGEMM when number of local experts is > 1 + - Distributed checkpointing +- Interleaved rotary embedding + +### Datasets + +- Masked WordPiece datasets for BERT and T5 +- Raw and mock datasets + +### Parallelism + +### Performance + +- Activation offloading to CPU +- Rope and Swiglu fusion +- Sliding window attention (via Transformer Engine) + +### General Improvements + +- Timers + +## NVIDIA Megatron Core 0.4.0 + +### Key Features and Enhancements + +#### Models + +- BERT +- RETRO +- T5 + +#### Parallelism + +- Mixture of Experts support for GPT +- Model parallel efficient Distributed Data Parallel (DDP) +- Context Parallel (2D Tensor Parallel) support + +#### Datasets + +- GPT Dataset +- Blended Dataset diff --git a/CODEOWNERS b/CODEOWNERS index e89c62b06e917127c779e51f878faeb7d43bfbda..5a01a3cc111b3b93decd07f26cafd44910e61b77 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,49 +1,49 @@ -[Core-ADLR] @mcore-reviewers/core-adlr -megatron/core/ - -[Core-NeMo] @mcore-reviewers/core-nemo -megatron/core/ - -^[Core-MLPerf] @mcore-reviewers/mlperf -megatron/core/ - -[MoE-ADLR] @mcore-reviewers/moe-adlr -megatron/core/transformer/moe/ - -[MoE-Moe] @mcore-reviewers/moe-moe -megatron/core/transformer/moe/ - -[Datasets] @mcore-reviewers/datasets -megatron/core/datasets/ - -[BERT] @mcore-reviewers/bert -megatron/core/models/bert/ - -[GPT] @mcore-reviewers/gpt -megatron/core/models/gpt/ - -[Retro] @mcore-reviewers/retro -megatron/core/models/retro/ - -[Distributed Checkpointing] @mcore-reviewers/dist-checkpointing -megatron/core/dist_checkpointing/ - -[Distributed Optimizer] @mcore-reviewers/dist-optimizer -megatron/core/optimizer/distrib_optimizer/ - -[Inference] @mcore-reviewers/inference -megatron/core/inference/ - -^[Quantization and Inference (QAT)] @mcore-reviewers/quantization-and-inference -megatron/core/inference/ - -; [Context Parallelism] @mcore-reviewers/context-parallelism -; - -[CI] @mcore-reviewers/ci -.gitlab/ -.github/ -.gitlab-ci.yml -Dockerfile.ci.lts -Dockerfile.ci.dev -tests/ +[Core-ADLR] @mcore-reviewers/core-adlr +megatron/core/ + +[Core-NeMo] @mcore-reviewers/core-nemo +megatron/core/ + +^[Core-MLPerf] @mcore-reviewers/mlperf +megatron/core/ + +[MoE-ADLR] @mcore-reviewers/moe-adlr +megatron/core/transformer/moe/ + +[MoE-Moe] @mcore-reviewers/moe-moe +megatron/core/transformer/moe/ + +[Datasets] @mcore-reviewers/datasets +megatron/core/datasets/ + +[BERT] @mcore-reviewers/bert 
+megatron/core/models/bert/ + +[GPT] @mcore-reviewers/gpt +megatron/core/models/gpt/ + +[Retro] @mcore-reviewers/retro +megatron/core/models/retro/ + +[Distributed Checkpointing] @mcore-reviewers/dist-checkpointing +megatron/core/dist_checkpointing/ + +[Distributed Optimizer] @mcore-reviewers/dist-optimizer +megatron/core/optimizer/distrib_optimizer/ + +[Inference] @mcore-reviewers/inference +megatron/core/inference/ + +^[Quantization and Inference (QAT)] @mcore-reviewers/quantization-and-inference +megatron/core/inference/ + +; [Context Parallelism] @mcore-reviewers/context-parallelism +; + +[CI][2] @mcore-reviewers/ci +.gitlab/ +.github/ +.gitlab-ci.yml +Dockerfile.ci.lts +Dockerfile.ci.dev +tests/ diff --git a/Dockerfile.ci.dev b/Dockerfile.ci.dev index c631282c2de3a15e80389182a8cb421e85464e24..074d2039f24f1d7a2775b109d5dc01d90d94d494 100644 --- a/Dockerfile.ci.dev +++ b/Dockerfile.ci.dev @@ -1,76 +1,84 @@ -# syntax=docker/dockerfile:1.3-labs - -ARG FROM_IMAGE_NAME -FROM $FROM_IMAGE_NAME as build_causal_conv1d -WORKDIR /opt -RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 - -FROM $FROM_IMAGE_NAME as build_grouped_gemm -WORKDIR /opt -RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2 - -FROM $FROM_IMAGE_NAME as build_mamba_ssm -WORKDIR /opt -RUN MAMBA_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/state-spaces/mamba.git@v2.2.0 - -FROM $FROM_IMAGE_NAME as main -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update && \ - apt-get install -y --no-install-recommends gettext python3-venv && \ - apt-get clean && \ - python -m venv /opt/jet && \ - wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ - chmod a+x /usr/local/bin/yq - -COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ -COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ -COPY --from=build_mamba_ssm /opt/mamba_ssm-*.whl ./ - -RUN \ - --mount=type=bind,source=requirements,target=requirements \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - --mount=type=bind,source=setup.py,target=setup.py \ - --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ - --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ - --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex - -pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl -PY_ENV=pytorch:24.07 pip install . 
-EOF - -# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker -ARG MCORE_REPO -ARG MCORE_REF -ARG MCORE_BACKWARDS_REF -RUN <<"EOF" bash -exu -# Checkout latest -cd /opt -rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm -git init -git remote add origin ${MCORE_REPO} -git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' -git fetch origin $MCORE_REF -git checkout $MCORE_REF - -# Checkout backwards-ref -cd /opt -rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy -git init -git remote add origin ${MCORE_REPO} -git fetch origin $MCORE_BACKWARDS_REF -git checkout $MCORE_BACKWARDS_REF -rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ -EOF - -RUN PY_ENV=pytorch:24.07 pip install -e /opt/megatron-lm -ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" - -##### For NVIDIANS only ##### -FROM main as jet -ARG CACHEBUST=0 -RUN --mount=type=secret,id=JET_INDEX_URLS \ - JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ - pip install jet-client jet-api --upgrade $JET_INDEX_URLS -ENV PATH="$PATH:/opt/jet/bin" +# syntax=docker/dockerfile:1.3-labs + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as build_causal_conv1d +WORKDIR /opt +RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 + +FROM $FROM_IMAGE_NAME as build_grouped_gemm +WORKDIR /opt +RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2 + +FROM $FROM_IMAGE_NAME as build_mamba_ssm +WORKDIR /opt +RUN git clone https://github.com/state-spaces/mamba.git && \ + cd mamba && \ + git checkout v2.2.0 && \ + sed -i "/triton/d" setup.py && \ + MAMBA_FORCE_BUILD=TRUE pip3 wheel -v . + +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends gettext python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet && \ + wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ + chmod a+x /usr/local/bin/yq + +COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ +COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ +COPY --from=build_mamba_ssm /opt/mamba/mamba_ssm-*.whl ./ + +RUN \ + --mount=type=bind,source=requirements,target=requirements \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + --mount=type=bind,source=setup.py,target=setup.py \ + --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ + --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ + --mount=type=bind,source=megatron/core/requirements.txt,target=megatron/core/requirements.txt \ + --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex + +pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl +PY_ENV=pytorch_24.10 pip install . 
+EOF + +# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker +ARG MCORE_REPO +ARG MCORE_REF +ARG MCORE_BACKWARDS_REF +RUN <<"EOF" bash -exu +# Checkout latest +cd /opt +rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm +git init +git remote add origin ${MCORE_REPO} +git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' +git fetch origin $MCORE_REF +git checkout $MCORE_REF + +# Checkout backwards-ref +cd /opt +rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy +git init +git remote add origin ${MCORE_REPO} +git fetch origin $MCORE_BACKWARDS_REF +git checkout $MCORE_BACKWARDS_REF +rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ +EOF + +RUN PY_ENV=pytorch_24.10 pip install -e /opt/megatron-lm +ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + --mount=type=secret,id=LOGGER_INDEX_URL \ + LOGGER_INDEX_URL=$(cat /run/secrets/LOGGER_INDEX_URL) && \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS && \ + pip install "one-logger" --upgrade $LOGGER_INDEX_URL +ENV PATH="$PATH:/opt/jet/bin" ### \ No newline at end of file diff --git a/Dockerfile.ci.lts b/Dockerfile.ci.lts index ea0cf31a0b405b20692280325653c4bde465cc40..a3d15e8d4801a59120898849b54832207c501b1a 100644 --- a/Dockerfile.ci.lts +++ b/Dockerfile.ci.lts @@ -1,77 +1,81 @@ -# syntax=docker/dockerfile:1.3-labs - -ARG FROM_IMAGE_NAME -FROM $FROM_IMAGE_NAME as build_causal_conv1d -WORKDIR /opt -RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 - -FROM $FROM_IMAGE_NAME as build_grouped_gemm -WORKDIR /opt -RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2 - -FROM $FROM_IMAGE_NAME as build_mamba_ssm -WORKDIR /opt -RUN MAMBA_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/state-spaces/mamba.git@v2.0.3 - -ARG FROM_IMAGE_NAME -FROM $FROM_IMAGE_NAME as main -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get update && \ - apt-get install -y --no-install-recommends gettext python3-venv && \ - apt-get clean && \ - python -m venv /opt/jet && \ - wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ - chmod a+x /usr/local/bin/yq - -COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ -COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ -COPY --from=build_mamba_ssm /opt/mamba_ssm-*.whl ./ - -RUN \ - --mount=type=bind,source=requirements,target=requirements \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - --mount=type=bind,source=setup.py,target=setup.py \ - --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ - --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ - --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex - -pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl -PY_ENV=pytorch:24.07 pip install . 
-EOF - -# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker -ARG MCORE_REPO -ARG MCORE_REF -ARG MCORE_BACKWARDS_REF -RUN <<"EOF" bash -exu -# Checkout latest -cd /opt -rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm -git init -git remote add origin ${MCORE_REPO} -git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' -git fetch origin $MCORE_REF -git checkout $MCORE_REF - -# Checkout backwards-ref -cd /opt -rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy -git init -git remote add origin ${MCORE_REPO} -git fetch origin $MCORE_BACKWARDS_REF -git checkout $MCORE_BACKWARDS_REF -rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ -EOF - -RUN PY_ENV=pytorch:24.01 pip install -e /opt/megatron-lm -ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" - -##### For NVIDIANS only ##### -FROM main as jet -ARG CACHEBUST=0 -RUN --mount=type=secret,id=JET_INDEX_URLS \ - JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ - pip install jet-api jet-client --upgrade $JET_INDEX_URLS -ENV PATH="$PATH:/opt/jet/bin" +# syntax=docker/dockerfile:1.3-labs + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as build_causal_conv1d +WORKDIR /opt +RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 + +FROM $FROM_IMAGE_NAME as build_grouped_gemm +WORKDIR /opt +RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2 + +FROM $FROM_IMAGE_NAME as build_mamba_ssm +WORKDIR /opt +RUN MAMBA_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/state-spaces/mamba.git@v2.0.3 + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends gettext python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet && \ + wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ + chmod a+x /usr/local/bin/yq + +COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ +COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ +COPY --from=build_mamba_ssm /opt/mamba_ssm-*.whl ./ + +RUN \ + --mount=type=bind,source=requirements,target=requirements \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + --mount=type=bind,source=setup.py,target=setup.py \ + --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ + --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ + --mount=type=bind,source=megatron/core/requirements.txt,target=megatron/core/requirements.txt \ + --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex + +pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl +PY_ENV=pytorch_24.01 pip install . 
+EOF + +# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker +ARG MCORE_REPO +ARG MCORE_REF +ARG MCORE_BACKWARDS_REF +RUN <<"EOF" bash -exu +# Checkout latest +cd /opt +rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm +git init +git remote add origin ${MCORE_REPO} +git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' +git fetch origin $MCORE_REF +git checkout $MCORE_REF + +# Checkout backwards-ref +cd /opt +rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy +git init +git remote add origin ${MCORE_REPO} +git fetch origin $MCORE_BACKWARDS_REF +git checkout $MCORE_BACKWARDS_REF +rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ +EOF + +RUN PY_ENV=pytorch_24.01 pip install -e /opt/megatron-lm +ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + --mount=type=secret,id=LOGGER_INDEX_URL \ + LOGGER_INDEX_URL=$(cat /run/secrets/LOGGER_INDEX_URL) && \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS && \ + pip install "one-logger" --upgrade $LOGGER_INDEX_URL +ENV PATH="$PATH:/opt/jet/bin" ### \ No newline at end of file diff --git a/Dockerfile.linting b/Dockerfile.linting index ff1a28cefd824e06edd8ea37c130490d8277ca6d..608e2587eae7fc4e5c640813d6fe41136130cf6e 100644 --- a/Dockerfile.linting +++ b/Dockerfile.linting @@ -1,33 +1,34 @@ -# syntax=docker/dockerfile:experimental - -ARG FROM_IMAGE_NAME -FROM $FROM_IMAGE_NAME as main -ENV DEBIAN_FRONTEND=noninteractive - -RUN sed -i -e 's/^APT/# APT/' -e 's/^DPkg/# DPkg/' \ - /etc/apt/apt.conf.d/docker-clean - -RUN apt-get update && \ - apt-get install -y python3-venv && \ - apt-get clean && \ - python -m venv /opt/jet - -RUN pip3 install --no-cache-dir \ - black==24.4.2 \ - isort==5.13.2 \ - flake8==7.1.0 \ - pylint==3.2.6 \ - mypy - -COPY . /opt/megatron-lm - -WORKDIR /opt/megatron-lm - -##### For NVIDIANS only ##### -FROM main as jet -ARG CACHEBUST=0 -RUN --mount=type=secret,id=JET_INDEX_URLS \ - JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ - pip install jet-client jet-api --upgrade $JET_INDEX_URLS -ENV PATH="$PATH:/opt/jet/bin" +# syntax=docker/dockerfile:experimental + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN sed -i -e 's/^APT/# APT/' -e 's/^DPkg/# DPkg/' \ + /etc/apt/apt.conf.d/docker-clean + +RUN apt-get update && \ + apt-get install -y python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet + +RUN pip3 install --no-cache-dir \ + black==24.4.2 \ + isort==5.13.2 \ + flake8==7.1.0 \ + pylint==3.2.6 \ + coverage \ + mypy + +COPY . 
/opt/megatron-lm + +WORKDIR /opt/megatron-lm + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS +ENV PATH="$PATH:/opt/jet/bin" ### \ No newline at end of file diff --git a/GPT_pretraining.sh b/GPT_pretraining.sh deleted file mode 100644 index f0f3fc9c8f1bab02bf3a4eec6b5610688d251d9f..0000000000000000000000000000000000000000 --- a/GPT_pretraining.sh +++ /dev/null @@ -1,157 +0,0 @@ -#!/bin/bash - -# Runs the "7B" parameter model -export HSA_FORCE_FINE_GRAIN_PCIE=1 -export OMP_NUM_THREADS=1 -export NCCL_P2P_LEVEL=SYS - -export NCCL_ALGO=Ring -export NCCL_NCHANNELS_PER_PEER=16 -export NCCL_MIN_NCHANNELS=20 -export NCCL_IB_TIMEOUT=22 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -export NCCL_NET_GDR_LEVEL=SYS -export NCCL_NET_GDR_READ=0 - - -CHECKPOINT_PATH=./tmp #$1 # -TENSORBOARD_LOGS_PATH=./tmp #$2 # -DATA_PATH="/datasets/oscar-1GB-gpt_text_document" #_text_document -VOCAB_PATH=./gpt2-vocab.json -MERGE_PATH=./gpt2-merges.txt - -GPT_MODEL_ARGS=( - --num-layers 12 - --hidden-size 768 - --num-attention-heads 12 - --ffn-hidden-size 3072 - --seq-length 1024 - --max-position-embeddings 1024 -) - -# export NVTE_FLASH_ATTN=1 # 走autlass -# export NVTE_FLASH_ATTN_TRITON=1 # 走triton_fa -# --transformer-impl transformer_engine - # --use-mcore-models -TRAINING_ARGS=( - --transformer-impl local - --use-legacy-models - --micro-batch-size 1 - --global-batch-size 60 #240 #512 #64 - --train-iters 100 - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.95 - --init-method-std 0.006 - --clip-grad 1.0 - --bf16 - --use-distributed-optimizer - --ckpt-format torch - --disable-bias-linear - --overlap-grad-reduce - --attention-dropout 0 - --hidden-dropout 0 - --ddp-average-in-collective - --recompute-granularity full - --recompute-num-layers 5 - --recompute-method block - --no-gradient-accumulation-fusion - --swiglu - --lr 3.0e-5 - --lr-decay-style cosine - --min-lr 3.0e-6 - --lr-warmup-iters 1 -) -MODEL_PARALLEL_ARGS=( - --sequence-parallel - --tensor-model-parallel-size 2 - --pipeline-model-parallel-size 2 -) - -DATA_ARGS=( - --data-path $DATA_PATH - --split 949,50,1 - --untie-embeddings-and-output-weights - --use-rotary-position-embeddings - --normalization RMSNorm - --no-position-embedding - --vocab-file $VOCAB_PATH - --merge-file $MERGE_PATH - --tokenizer-type GPT2BPETokenizer -) - -EVAL_AND_LOGGING_ARGS=( - --log-interval 1 - --save-interval 10000 - --eval-interval 1000 - --save $CHECKPOINT_PATH - --load $CHECKPOINT_PATH - --eval-iters 10 - --tensorboard-dir $TENSORBOARD_LOGS_PATH -) - -RANK=$OMPI_COMM_WORLD_RANK -LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK -WORLD_SIZE=$OMPI_COMM_WORLD_SIZE -DIST_URL=${1} -DIST_PORT=34566 - -DISTRIBUTED_ARGS=( - --rank ${RANK} - --world-size ${WORLD_SIZE} - --local-rank ${LOCAL_RANK} - --dist-url tcp://${DIST_URL}:${DIST_PORT} -) - -APP="python -u pretrain_gpt.py \ - ${GPT_MODEL_ARGS[@]} \ - ${TRAINING_ARGS[@]} \ - ${MODEL_PARALLEL_ARGS[@]} \ - ${DATA_ARGS[@]} \ - ${EVAL_AND_LOGGING_ARGS[@]} \ - ${DISTRIBUTED_ARGS[@]} \ -" - -case ${LOCAL_RANK} in -[0]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[1]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[2]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[3]) - export 
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[4]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[5]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[6]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[7]) - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -esac diff --git a/LICENSE b/LICENSE index b4193aff5025430b3352e6c601777be0e7565d6b..57e1320a1117934ee57f66410779a61b987ec3a1 100644 --- a/LICENSE +++ b/LICENSE @@ -1,272 +1,273 @@ -The following applies to all files unless otherwise noted: - -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --- - -This repository also contains code from Hugging Face Inc., Google Research, -Facebook (from their Fairseq, Dino, and ParlAI projects), Microsoft (from their -Swin-Transformer project), Philip Popien, the Mamba project (Tri Dao and -Albert Gu), and the Triton language and compiler project (Philippe Tillet and -OpenAI). Files from these organizations have notices at the top of each file. -Below are licenses used in those files, as indicated. - - --------------------------------------------------------------------------------- --- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, and Mamba code -- - - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- --------------------------------------------------------------------------------- -LICENSE FOR -Facebook, Inc. and its affiliates, -Meta Platforms, Inc. and its affiliates, -Microsoft Corporation, -OpenGVLab/InternVL, and -Triton language and compiler. - -MIT License - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - +The following applies to all files unless otherwise noted: + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-- + +This repository also contains code from Hugging Face Inc., Google Research, +Facebook (from their Fairseq, Dino, and ParlAI projects), Microsoft (from their +Swin-Transformer project), Philip Popien, the Mamba project (Tri Dao and +Albert Gu), and the Triton language and compiler project (Philippe Tillet and +OpenAI). Files from these organizations have notices at the top of each file. +Below are licenses used in those files, as indicated. 
+ + +-------------------------------------------------------------------------------------- +-- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, Mamba, and vLLM code -- + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- +LICENSE FOR +Facebook, Inc. and its affiliates, +Meta Platforms, Inc. and its affiliates, +Microsoft Corporation, +OpenGVLab/InternVL, +Triton language and compiler, +and DeepSeek. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/Llama_pretraining.sh b/Llama_pretraining.sh deleted file mode 100644 index 5f9d820be8908bfaef6e590f1eac17ba61952ce9..0000000000000000000000000000000000000000 --- a/Llama_pretraining.sh +++ /dev/null @@ -1,211 +0,0 @@ -#!/bin/bash -set -eux - -#export FLASH_ATTENTION_PRINT_PARAM=1 -# Runs the "7B" parameter model -export HSA_FORCE_FINE_GRAIN_PCIE=1 -export OMP_NUM_THREADS=1 -export NCCL_P2P_LEVEL=PXB # SYS - -#export HIP_ALLOC_INITIALIZE=0 -#export GPU_MAX_HW_QUEUES=20 - -export NCCL_ALGO=Ring -export NCCL_NCHANNELS_PER_PEER=16 -export NCCL_MIN_NCHANNELS=20 -export NCCL_IB_TIMEOUT=22 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -export NCCL_IB_HCA=mlx5_1,mlx5_2 -export NCCL_NET_GDR_LEVEL=7 -export NCCL_NET_GDR_READ=1 -export GLOG_minloglevel=3 # 打印error级别的nccl日志 -source /opt/dtk/env.sh -# 导入hipblaslt库 -# export LD_LIBRARY_PATH=/data/hipblaslt-install-0904/lib:$LD_LIBRARY_PATH -# 更新rocblas -# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1211/lib:$LD_LIBRARY_PATH -# export LD_LIBRARY_PATH=/data/rocblas-install_qwen1228/lib:$LD_LIBRARY_PATH -# export LD_LIBRARY_PATH=/data/rocblas-install-0118-bf16/lib:$LD_LIBRARY_PATH - -# torch控制多流转单流 -export ALLREDUCE_STREAM_WITH_COMPUTE=1 -export SENDRECV_STREAM_WITH_COMPUTE=1 - -# prof采集添加同步, 避免卡顿 -# export GPU_FLUSH_ON_EXECUTION=1 -# export HIP_DIRECT_DISPATCH=0 - -# 采集rocblas size -# export ROCBLAS_LAYER=3 -# 采集 fa size -# export FLASH_ATTENTION_PRINT_PARAM=1 - -#增加编译缓存 -export cache_size_limit=64 - -CHECKPOINT_PATH=./tmp_7b #$1 # -TENSORBOARD_LOGS_PATH=./tmp_7b #$2 # -DATA_PATH="/data/datasets/nemo_pretrain/oscar-1GB/oscar-1GB-llama_text_document" #_text_document - -GPT_MODEL_ARGS=( - --num-layers 32 - --hidden-size 4096 - --ffn-hidden-size 11008 - --num-attention-heads 32 - --max-position-embeddings 4096 - - --normalization RMSNorm - 
--position-embedding-type rope - --untie-embeddings-and-output-weights # 分开处理embed和输出权重, 增加灵活性 -) - -# export NVTE_FLASH_ATTN=1 # 走cutlass -export NVTE_FLASH_ATTN_TRITON=1 # 走triton_fa -# --transformer-impl transformer_engine # 走core用这两组参数 - # --use-mcore-models - # --transformer-impl local # 走legacy用这两组参数 - # --use-legacy-models -TRAINING_ARGS=( - --transformer-impl local # 走legacy用这两组参数 - --use-legacy-models - --micro-batch-size 1 - --global-batch-size 60 #240 #60 #512 #64 - --train-iters 10 - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.95 - --init-method-std 0.006 - --clip-grad 1.0 - --bf16 - # --fp16 # 开启fp16需要指定loss-scale - # --loss-scale 1024 - --use-distributed-optimizer - --disable-bias-linear - --attention-dropout 0 - --hidden-dropout 0 - --no-gradient-accumulation-fusion - --swiglu - --lr 3.0e-5 - --lr-decay-style cosine - --min-lr 3.0e-6 - --lr-warmup-iters 1 - --ckpt-format torch - --ddp-average-in-collective # 在dp阶段通信中, 梯度或参数将被直接平均, 而不是先求和(到一个设备)再平均 - # --recompute-granularity full # 开启重计算降低显存增加耗时 - # --recompute-num-layers 5 #0 # - # --recompute-method block - --overlap-grad-reduce # 重叠ddp grad reduce - # --tp-comm-overlap # tensor parallel comm和gemm重叠 - # --tp-comm-overlap-rs-dgrad # reduce-scatter和dgrad gemm重叠 - --use-flash-attn-triton -) -# --use-flash-attn-cutlass # cutlass fa -# --use-flash-attn-triton # triton fa - -MODEL_PARALLEL_ARGS=( - --sequence-parallel - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 2 -) - -DATA_ARGS=( - --data-path $DATA_PATH - --seq-length 4096 #4096 - --split 949,50,1 - --tokenizer-type Llama2Tokenizer - --tokenizer-model /data/model_weights/llama2_7b_hf/tokenizer.model -) - -EVAL_AND_LOGGING_ARGS=( - --log-interval 1 - --log-throughput - --save-interval 1000 - --eval-interval 1000 - --save $CHECKPOINT_PATH - --load $CHECKPOINT_PATH - --eval-iters 10 - --tensorboard-dir $TENSORBOARD_LOGS_PATH -) - -PROFILE_ARGS=( - --profile - --profile-step-start 4 - --profile-step-end 5 - --use-pytorch-profiler - --profile-ranks 0 1 2 3 4 5 6 7 - --profile-dir prof_data -) - -RANK=$OMPI_COMM_WORLD_RANK -LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK -WORLD_SIZE=$OMPI_COMM_WORLD_SIZE -DIST_URL=${1} -DIST_PORT=34567 - -DISTRIBUTED_ARGS=( - --rank ${RANK} - --world-size ${WORLD_SIZE} - --local-rank ${LOCAL_RANK} - --dist-url tcp://${DIST_URL}:${DIST_PORT} -) - -APP="python -u pretrain_gpt.py \ - ${GPT_MODEL_ARGS[@]} \ - ${TRAINING_ARGS[@]} \ - ${MODEL_PARALLEL_ARGS[@]} \ - ${DATA_ARGS[@]} \ - ${EVAL_AND_LOGGING_ARGS[@]} \ - ${DISTRIBUTED_ARGS[@]} \ - -" -# 开启profile -# ${PROFILE_ARGS[@]} \ - -# export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # # 4,5,6,7 #, -# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3, -# ${APP} - -# 使用numactl绑定 -case ${LOCAL_RANK} in -[0]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[1]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - numactl --cpunodebind=1 --membind=1 ${APP} - ;; -[2]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - numactl --cpunodebind=2 --membind=2 ${APP} - ;; -[3]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - numactl --cpunodebind=3 --membind=3 ${APP} - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[4]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - numactl --cpunodebind=4 
--membind=4 ${APP} - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[5]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - numactl --cpunodebind=5 --membind=5 ${APP} - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[6]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - numactl --cpunodebind=6 --membind=6 ${APP} - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - ;; -[7]) - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - numactl --cpunodebind=7 --membind=7 ${APP} - # hipprof --hip-trace --trace-off numactl --cpunodebind=0 --membind=0 ${APP} - ;; -esac diff --git a/README.md b/README.md index f5e8385c0a419ef2e9eb4409558200a8438001f8..d69e737f36c18f0eccb63e655075074454d252e1 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ # 更新日志 +2025.3.14适配最新代码,shell启动脚本在examples对应模型目录下 + 2024.12.16适配了torch prof 使用方法: 启动脚本中添加下列参数, 即可采集对应的prof信息 @@ -23,9 +25,6 @@ ```python # 采集torchprof mpirun -np 8 --allow-run-as-root train_mixtral_8x7B_1nodes.sh localhost --profiling=torch - -# 采集hipprof -mpirun -np 8 --allow-run-as-root train_mixtral_8x7B_1nodes.sh localhost --profiling=hip ``` ```bash @@ -38,14 +37,6 @@ TORCH_PROFIE_ARGS=( --profile-ranks 0 3 # 采集全局rank 第0和3 --profile-dir ./prof_data # prof文件的保存目录 ) - -HIP_PROFIE_ARGS=( - --profile - --profile-ranks 0 1 2 3 4 5 6 7 - --profile-step-start 4 - --profile-step-end 5 - --use-hip-profiler -) ``` diff --git a/docs/llama_mistral.md b/docs/llama_mistral.md index 11601fd44f6d2e6c71b2817eeaf42f54ae29cb5f..81f158448984637bdd25d53db8829aea5c857e87 100644 --- a/docs/llama_mistral.md +++ b/docs/llama_mistral.md @@ -1,480 +1,444 @@ -# Llama, Mistral and other Llama-like model support in Megatron-LM - -NOTE: In order to simplify code we now only support converting llama-3.x and mistral checkpoints downloaded from Huggingface. - -The [Llama-2](https://ai.meta.com/llama/) and [Llama-3](https://llama.meta.com/) family of models are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At their times of release, both Llama-2 and Llama-3 models achieved among the best results for open-source models, and were competitive with leading closed-source models (see https://arxiv.org/pdf/2307.09288.pdf and https://ai.meta.com/blog/meta-llama-3/). - -Similarly, [Mistral-7b](https://mistral.ai/news/announcing-mistral-7b/) is an open-source model with pretrained and finetuned (for chat) variants that achieve strong benchmark results. - -Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatron can support loading checkpoints from all three for inference and finetuning. Converting the checkpoints and loading them is slightly different for each model and is detailed for each below. - -# Llama-2 - -Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps: - -1. Get access to download the checkpoints. -2. Convert the checkpoints from Meta/Huggingface format to Megatron format. -3. Setup arguments for launching the model. - -The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints. 
- -## Contents - * [Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints) - * [Convert checkpoint format](#convert-checkpoint-format) - * [Meta format](#meta-format) - * [Huggingface format](#huggingface-format) - * [Launch model](#launch-model) - * [Megatron](#launch-megatron) - * [Meta](#launch-meta) - * [Huggingface](#launch-hf) - * [Benchmark results](#benchmark-results) - -## Download Meta or Huggingface checkpoints - -Users must first apply for access to download the Llama-2 checkpoints either directly from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next. - -## Convert checkpoint format - -We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. - -### Meta format - -The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16: - -``` -python tools/checkpoint/convert.py --model-type GPT \ -> --loader llama_mistral \ -> --saver megatron \ -> --checkpoint-type meta \ -> --model-size llama2-7B \ -> --load-dir $LLAMA_META_FORMAT_DIR \ -> --save-dir ${MEGATRON_FORMAT_DIR} \ -> --tokenizer-model ${TOKENIZER_MODEL} \ -> --target-tensor-parallel-size ${TP} \ -> --target-pipeline-parallel-size ${PP} \ -> --bf16 -``` - -Valid values for `--model-size` are `llama2-7B`, `llama2-13B`, and `llama2-70B` (for pretrained-only models), and `llama2-7Bf`, `llama2-13Bf`, and `llama2-70Bf` (for chat-finetuned models). - -### Huggingface format - -The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: - -| Model size | Tensor parallel size (`TP`) | -| ---------- | --------------------------- | -| 7B | 1 | -| 13B | 2 | -| 70B | 8 | - -Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: - -``` -$>: python tools/checkpoint/convert.py \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver megatron \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} -``` - -After this conversion, we are ready to load the checkpoints into a Megatron GPT model. 
- -## Launch model - -### Launch Megatron - -If loading for either inference or finetuning, use the following arguments: - -``` ---tensor-model-parallel-size ${TP} \ ---pipeline-model-parallel-size 1 \ ---seq-length 4096 \ ---max-position-embeddings 4096 \ ---tokenizer-type Llama2Tokenizer \ ---tokenizer-model ${TOKENIZER_MODEL} \ ---load ${CHECKPOINT_DIR} \ ---exit-on-missing-checkpoint \ ---use-checkpoint-args \ ---no-load-optim \ ---no-load-rng \ ---untie-embeddings-and-output-weights \ ---use-rotary-position-embeddings \ ---normalization RMSNorm \ ---no-position-embedding \ ---no-masked-softmax-fusion \ ---attention-softmax-in-fp32 -``` - -### Launch Meta - -Meta checkpoints can be launched with: https://github.com/facebookresearch/llama - -### Launch Huggingface - -Huggingface checkpoints can be launched with: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - -## Benchmark results - -The tables below list the benchmark comparisons between native Llama-2 (using Meta's checkpoint and Meta's inference code) and Megatron (using a converted HF checkpoint and Megatron's inference code). - -The values are the percent error between Megatron and Llama-2, calculated using the formula: `| - | / `, where the type of score is detailed before each table. Across all tests (80 total per model size), the mean error is 0.15%. The small difference in benchmark scores between the two models is due to minor arithmetic differences in implementation that alter the numerics slightly. Some of the factors that influence this difference include: - -- Megatron performs batch matrix multiplications in a couple places, such as within self attention and in SwiGLU, that Llama performs separately. -- Megatron uses `torch.baddbmm` within self attention, versus Llama using `torch.matmul`. -- Megatron uses a `sin`/`cos` implementation for rotary position embeddings, versus Llama using a `polar`/`complex` implementation. -- Llama calls `torch.set_default_dtype(torch.float16)` during initialization, which Megatron does not. - -### Big Bench - -Score type: multiple choice grade. - -| bigbench / standard | 7b | 13b | 70b | -| -- | -- | -- | -- | -| date_understanding | 0.29% | 0.13% | 0.12% | -| general_knowledge | 0.00% | 0.00% | 0.00% | -| human_organs_senses | 0.00% | 0.00% | 0.00% | -| intent_recognition | 0.00% | 0.11% | 0.00% | -| riddle_sense | 0.00% | 0.00% | 0.00% | -| similarities_abstraction | 0.00% | 0.58% | 0.00% | -| simple_arithmetic_json_multiple_choice | 0.00% | 0.00% | 0.00% | -| undo_permutation | 0.19% | 0.19% | 0.18% | - -### Multilingual - -Score type: multiple choice grade. - -| multilingual / xcopa | 7b | 13b | 70b | -| -- | -- | -- | -- | -| en-template-mGPT-remove-punctuation | 0.08% | 0.00% | 0.00% | -| et-template-mGPT-remove-punctuation | 0.00% | 0.13% | 0.25% | -| ht-template-mGPT-remove-punctuation | 0.26% | 0.13% | 0.26% | -| id-template-mGPT-remove-punctuation | 0.11% | 0.00% | 0.19% | -| it-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% | -| qu-template-mGPT-remove-punctuation | 0.00% | 0.00% | 0.27% | -| sw-template-mGPT-remove-punctuation | 0.14% | 0.13% | 0.13% | -| th-template-mGPT-remove-punctuation | 0.25% | 0.13% | 0.13% | -| tr-template-mGPT-remove-punctuation | 0.26% | 0.00% | 0.34% | -| vi-template-mGPT-remove-punctuation | 0.00% | 0.11% | 0.00% | -| zh-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% | - -### LM Evaluation Harness - -Score type: multiple choice grade. 
- -| lm-eval | 7b | 13b | 70b | -| -- | -- | -- | -- | -| boolq | 0.04% | 0.04% | 0.07% | -| hellaswag | 0.02% | 0.03% | 0.03% | -| piqa | 0.00% | 0.00% | 0.07% | -| winogrande | 0.00% | 0.11% | 0.20% | - -### MMLU - -Score type: multiple choice grade. - -Note: the number in brackets is the number of sub-tasks for each supercategory. - -| mmlu | 7b | 13b | 70b | -| -- | -- | -- | -- | -| stem [18] | 0.79% | 0.05% | 0.01% | -| humanities [13] | 0.19% | 0.01% | 0.02% | -| other (business, health, misc.) [14] | 0.08% | 0.06% | 0.12% | -| social sciences [12] | 0.37% | 0.21% | 0.01% | - -# Llama-3 - -Llama-3 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of several steps: - -1. Get access to download the checkpoints (weights and tokenizer). -2. Convert the checkpoints from Huggingface format to Megatron format. -3. (Optional) Validate converted checkpoints -4. Setup arguments for launching the model. - -The following sections detail these steps. - -## Contents - * [Download Huggingface checkpoints](#download-huggingface-checkpoints) - * [Convert checkpoint format](#convert-checkpoint-format) - * [Huggingface format](#huggingface-format) - * [Validate checkpoint](#optional-validate-checkpoint) - * [Launch model](#launch-model) - -## Download Huggingface checkpoints - -Users must first apply for access to download the Llama-3 checkpoints from [Huggingface](https://huggingface.co/meta-llama). - -## Convert checkpoint format - -We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. - -### Huggingface format - -The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: - -| Model size | Tensor parallel size (`TP`) | -| ---------- | --------------------------- | -| 8B | 1 | -| 70B | 8 | - -Using these values for `TP`, along with the path to the Llama-3 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: - -``` -$>: python tools/checkpoint/convert.py \ - > --bf16 \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver mcore \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} - > --model-size llama3-8B \ -``` - -Valid values for `--model-size` are `llama3-8B` and `llama3-70B` (for pretrained-only models), and `llama3-8Bf` and `llama3-70Bf` (for chat-finetuned models). - -After this conversion, we are ready to load the checkpoints into a Megatron GPT model. - -## (Optional) Validate checkpoints - -A Megatron-LM text generation server for Llama3 can be launched using the script `examples/llama_mistral/run_text_generation_llama3.sh `. - -Once running, query the server with `curl 'http://:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":[""], "tokens_to_generate":100, "top_k":1}'`. - -A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path --prompt `. 
- -## Launch model - -If loading for either inference or finetuning, use the following arguments: - -``` ---tensor-model-parallel-size ${TP} \ ---pipeline-model-parallel-size 1 \ ---seq-length 8192 \ ---max-position-embeddings 8192 \ ---tokenizer-type HuggingFaceTokenizer \ ---tokenizer-model ${TOKENIZER_MODEL} \ ---load ${CHECKPOINT_DIR} \ ---exit-on-missing-checkpoint \ ---use-checkpoint-args \ ---no-load-optim \ ---no-load-rng \ ---untie-embeddings-and-output-weights \ ---normalization RMSNorm \ ---position-embedding-type rope \ ---no-masked-softmax-fusion \ ---attention-softmax-in-fp32 \ ---disable-bias-linear \ ---transformer-impl transformer_engine \ ---group-query-attention 8 \ ---attention-dropout 0.0 \ ---hidden-dropout 0.0 \ ---rotary-base 500000 \ ---rotary-percent 1.0 \ ---ffn-hidden-size 14336 \ ---num-attention-heads 32 \ ---swiglu \ ---bf16 \ -``` - -# Llama-3.1 - -Llama-3 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of several steps: - -1. Get access to download the checkpoints (weights and tokenizer). -2. Convert the checkpoints from Huggingface format to Megatron format. -3. (Optional) Validate converted checkpoints -4. Setup arguments for launching the model. - -The following sections detail these steps. - -## Contents - * [Download Huggingface checkpoints](#download-huggingface-checkpoints) - * [Convert checkpoint format](#convert-checkpoint-format) - * [Huggingface format](#huggingface-format) - * [Validate checkpoint](#optional-validate-checkpoint) - * [Launch model](#launch-model) - -## Download Huggingface checkpoints - -Users must first apply for access to download the Llama-3 checkpoints from [Huggingface](https://huggingface.co/meta-llama). - -## Convert checkpoint format - -We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. - -### Huggingface format - -The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: - -| Model size | Tensor parallel size (`TP`) | -| ---------- | --------------------------- | -| 8B | 1 | -| 70B | 8 | - -Using these values for `TP`, along with the path to the Llama-3 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: - -``` -$>: python tools/checkpoint/convert.py \ - > --bf16 \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver mcore \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} - > --model-size llama3-8B \ -``` - -Valid values for `--model-size` are `llama3.1-8B` and `llama3.1-70B` (for pretrained-only models), and `llama3.1-8Bf` and `llama3.1-70Bf` (for chat-finetuned models). - -After this conversion, we are ready to load the checkpoints into a Megatron GPT model. - -## (Optional) Validate checkpoints - -A Megatron-LM text generation server for Llama3.1 can be launched using the script `examples/llama_mistral/run_text_generation_llama3.1.sh `. 
- -Once running, query the server with `curl 'http://:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":[""], "tokens_to_generate":100, "top_k":1}'`. - -A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path --prompt `. - -## Launch model - -If loading for either inference or finetuning, use the following arguments: - -``` ---tensor-model-parallel-size ${TP} \ ---pipeline-model-parallel-size 1 \ ---seq-length 8192 \ ---max-position-embeddings 131072 \ ---tokenizer-type HuggingFaceTokenizer \ ---tokenizer-model ${TOKENIZER_MODEL} \ ---load ${CHECKPOINT_DIR} \ ---exit-on-missing-checkpoint \ ---use-checkpoint-args \ ---no-load-optim \ ---no-load-rng \ ---untie-embeddings-and-output-weights \ ---normalization RMSNorm \ ---position-embedding-type rope \ ---no-masked-softmax-fusion \ ---attention-softmax-in-fp32 \ ---disable-bias-linear \ ---transformer-impl transformer_engine \ ---group-query-attention 8 \ ---attention-dropout 0.0 \ ---hidden-dropout 0.0 \ ---rotary-base 500000 \ ---rotary-percent 1.0 \ ---use-rope-scaling \ ---ffn-hidden-size 14336 \ ---num-attention-heads 32 \ ---swiglu \ ---bf16 \ -``` - -# Mistral-7b - -Megatron currently supports loading the v0.3 release of Mistral-7b (which does not use sliding window attention and offers a larger 32768 vocabulary) for inference and finetuning. Loading these checkpoints consists of several steps: - -1. Get access to download the checkpoints (weights and tokenizer). -2. Convert the checkpoints from HuggingFace format to Megatron format. -3. (Optional) Validate converted checkpoints -4. Setup arguments for launching the model. - -The following sections detail these steps. - -## Contents - * [Download Huggingface checkpoints](#download-huggingface-checkpoints) - * [Convert checkpoint format](#convert-checkpoint-format) - * [(Optional) Validate checkpoint](#optional-validate-checkpoint) - * [Launch model](#launch-model) - -## Download Huggingface checkpoints - -Users must first apply for access to download the Mistral-7b checkpoints through [Huggingface](https://huggingface.co/mistralai/Mistral-7B-v0.3) (HF). - -## Convert checkpoint format - -The HF checkpoints can be converted to Megatron format by using Megatron's own Mistral checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). - -Using the path to the Mistral tokenizer model (downloaded alongside the HF checkpoint), run the following command from the root of your Megatron source code to convert from HF format to mcore format: - -``` -$>: python tools/checkpoint/convert.py \ - > --bf16 \ - > --model-type GPT \ - > --loader llama_mistral \ - > --saver mcore \ - > --target-tensor-parallel-size ${TP} \ - > --checkpoint-type hf \ - > --load-dir ${HF_FORMAT_DIR} \ - > --save-dir ${MEGATRON_FORMAT_DIR} \ - > --tokenizer-model ${TOKENIZER_MODEL} \ - > --model-size mistral-7B \ -``` - -Valid values for `--model-size` are mistral-7B for the pretrained model or mistral-7Bf for the chat fine-tuned model. - -After this conversion, we are ready to load the checkpoints into an mcore GPT model. - -## (Optional) Validate checkpoints - -A Megatron-LM text generation server for Mistral-7B can be launched using the script `examples/llama_mistral/run_text_generation_mistral.sh `. 
- -Once running, query the server with `curl 'http://:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":[""], "tokens_to_generate":100, "top_k":1}'`. - -A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path --prompt `. - -## Launch model - -If loading for either inference or finetuning, use the following arguments: - -``` ---tensor-model-parallel-size ${TP} \ ---pipeline-model-parallel-size 1 \ ---seq-length 4096 \ ---max-position-embeddings 4096 \ ---tokenizer-type HuggingFaceTokenizer \ ---tokenizer-model ${TOKENIZER_MODEL} \ ---load ${CHECKPOINT_DIR} \ ---exit-on-missing-checkpoint \ ---use-checkpoint-args \ ---no-load-optim \ ---no-load-rng \ ---untie-embeddings-and-output-weights \ ---normalization RMSNorm \ ---position-embedding-type rope \ ---no-masked-softmax-fusion \ ---attention-softmax-in-fp32 ---apply-layernorm-1p \ ---transformer-impl transformer_engine \ ---group-query-attention 8 \ ---disable-bia-linear \ ---rotary-base 1000000 \ ---rotary-percent 1.0 \ ---swiglu \ ---ffn-hidden-size 14336 \ ---num-attention-heads 32 -``` - -# Other Llama-like model support - -*Note: Experimental* - -Many models such as Yi-34B use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama3](#llama-3). - -# Known numerical differences - -It is not expected that the megatron and Huggingface implementations of llama3.x and mistral models will produce numerically identical results. There are multiple points where small numerical differences are expected. This is a non-exhaustive list: - -1. TransformerEngine (TE) uses the model params_dtype inside RMSNorm whereas the Huggingface implementation uses fp32. See for details: https://github.com/NVIDIA/TransformerEngine/issues/1132 -2. Huggingface `transformers` implements the q, k and v projections in self-attention as separate GEMMs whereas mcore combines them into a single GEMM for efficiency. This leads to small numerical differences. - +# Llama, Mistral and other Llama-like model support in Megatron-LM + +NOTE: In order to simplify code we now only support converting llama-3.x and mistral checkpoints downloaded from Huggingface. + +The [Llama-2](https://ai.meta.com/llama/) and [Llama-3.x](https://llama.meta.com/) family of models are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At their times of release, both Llama-2 and Llama-3 models achieved among the best results for open-source models, and were competitive with leading closed-source models (see https://arxiv.org/pdf/2307.09288.pdf and https://ai.meta.com/blog/meta-llama-3/). + +Similarly, [Mistral-7b](https://mistral.ai/news/announcing-mistral-7b/) is an open-source model with pretrained and finetuned (for chat) variants that achieve strong benchmark results. + +Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatron can support loading checkpoints from all three for inference and finetuning. Converting the checkpoints and loading them is slightly different for each model and is detailed for each below. 
+ +# Contents + +- [Llama, Mistral and other Llama-like model support in Megatron-LM](#llama-mistral-and-other-llama-like-model-support-in-megatron-lm) +- [Contents](#contents) +- [Llama-2](#llama-2) + - [Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints) + - [Convert checkpoint format](#convert-checkpoint-format) + - [Meta format](#meta-format) + - [Huggingface format](#huggingface-format) + - [Launch model](#launch-model) + - [Launch Megatron](#launch-megatron) + - [Launch Meta](#launch-meta) + - [Launch Huggingface](#launch-huggingface) + - [Benchmark results](#benchmark-results) + - [Big Bench](#big-bench) + - [Multilingual](#multilingual) + - [LM Evaluation Harness](#lm-evaluation-harness) + - [MMLU](#mmlu) +- [Llama-3.x](#llama-3x) + - [Download Huggingface checkpoints](#download-huggingface-checkpoints) + - [Convert checkpoint format](#convert-checkpoint-format-1) + - [Huggingface format](#huggingface-format-1) + - [(Optional) Validate checkpoints](#optional-validate-checkpoints) + - [Launch model](#launch-model-1) +- [Mistral-7b](#mistral-7b) + - [Download Huggingface checkpoints](#download-huggingface-checkpoints-2) + - [Convert checkpoint format](#convert-checkpoint-format-3) + - [(Optional) Validate checkpoints](#optional-validate-checkpoints-2) + - [Launch model](#launch-model-3) +- [Other Llama-like model support](#other-llama-like-model-support) +- [Known numerical differences](#known-numerical-differences) +- [Using legacy model format](#using-legacy-model-format) + + +# Llama-2 + +Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps: + +1. Get access to download the checkpoints. +2. Convert the checkpoints from Meta/Huggingface format to Megatron format. +3. Setup arguments for launching the model. + +The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints. + +## Download Meta or Huggingface checkpoints + +Users must first apply for access to download the Llama-2 checkpoints either directly from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next. + +## Convert checkpoint format + +We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. + +### Meta format + +The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) 
Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16: + +``` +python tools/checkpoint/convert.py \ +> --model-type GPT \ +> --loader llama_mistral \ +> --load-dir ${META_FORMAT_DIR} \ +> --model-size ${MODEL_SIZE} \ +> --checkpoint-type meta \ +> --tokenizer-model ${TOKENIZER_MODEL} \ +> --saver core \ +> --save-dir ${MEGATRON_FORMAT_DIR} \ +> --target-tensor-parallel-size ${TP} \ +> --target-pipeline-parallel-size ${PP} \ +> --bf16 +``` + +Valid values for `--model-size` are `llama2-7B`, `llama2-13B`, and `llama2-70B` (for pretrained-only models), and `llama2-7Bf`, `llama2-13Bf`, and `llama2-70Bf` (for chat-finetuned models). + +### Huggingface format + +The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: + +| Model size | Tensor parallel size (`TP`) | +| ---------- | --------------------------- | +| 7B | 1 | +| 13B | 2 | +| 70B | 8 | + +Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: + +``` +python tools/checkpoint/convert.py \ +> --model-type GPT \ +> --loader llama_mistral \ +> --load-dir ${HF_FORMAT_DIR} \ +> --model-size ${MODEL_SIZE} \ +> --checkpoint-type hf \ +> --tokenizer-model ${TOKENIZER_MODEL} \ +> --saver core \ +> --save-dir ${MEGATRON_FORMAT_DIR} \ +> --target-tensor-parallel-size ${TP} \ +> --target-pipeline-parallel-size ${PP} \ +> --bf16 +``` + +After this conversion, we are ready to load the checkpoints into a Megatron GPT model. + +## Launch model + +### Launch Megatron + +If loading for either inference or finetuning, use the following arguments: + +``` +--tensor-model-parallel-size ${TP} \ +--pipeline-model-parallel-size 1 \ +--seq-length 4096 \ +--max-position-embeddings 4096 \ +--tokenizer-type Llama2Tokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--load ${CHECKPOINT_DIR} \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--no-load-optim \ +--no-load-rng \ +--untie-embeddings-and-output-weights \ +--use-rotary-position-embeddings \ +--normalization RMSNorm \ +--no-position-embedding \ +--no-masked-softmax-fusion \ +--attention-softmax-in-fp32 +``` + +**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format). + +### Launch Meta + +Meta checkpoints can be launched with: https://github.com/facebookresearch/llama + +### Launch Huggingface + +Huggingface checkpoints can be launched with: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +## Benchmark results + +The tables below list the benchmark comparisons between native Llama-2 (using Meta's checkpoint and Meta's inference code) and Megatron (using a converted HF checkpoint and Megatron's inference code). + +The values are the percent error between Megatron and Llama-2, calculated using the formula: `| - | / `, where the type of score is detailed before each table. 
Across all tests (80 total per model size), the mean error is 0.15%. The small difference in benchmark scores between the two models is due to minor arithmetic differences in implementation that alter the numerics slightly. Some of the factors that influence this difference include: + +- Megatron performs batch matrix multiplications in a couple places, such as within self attention and in SwiGLU, that Llama performs separately. +- Megatron uses `torch.baddbmm` within self attention, versus Llama using `torch.matmul`. +- Megatron uses a `sin`/`cos` implementation for rotary position embeddings, versus Llama using a `polar`/`complex` implementation. +- Llama calls `torch.set_default_dtype(torch.float16)` during initialization, which Megatron does not. + +### Big Bench + +Score type: multiple choice grade. + +| bigbench / standard | 7b | 13b | 70b | +| -- | -- | -- | -- | +| date_understanding | 0.29% | 0.13% | 0.12% | +| general_knowledge | 0.00% | 0.00% | 0.00% | +| human_organs_senses | 0.00% | 0.00% | 0.00% | +| intent_recognition | 0.00% | 0.11% | 0.00% | +| riddle_sense | 0.00% | 0.00% | 0.00% | +| similarities_abstraction | 0.00% | 0.58% | 0.00% | +| simple_arithmetic_json_multiple_choice | 0.00% | 0.00% | 0.00% | +| undo_permutation | 0.19% | 0.19% | 0.18% | + +### Multilingual + +Score type: multiple choice grade. + +| multilingual / xcopa | 7b | 13b | 70b | +| -- | -- | -- | -- | +| en-template-mGPT-remove-punctuation | 0.08% | 0.00% | 0.00% | +| et-template-mGPT-remove-punctuation | 0.00% | 0.13% | 0.25% | +| ht-template-mGPT-remove-punctuation | 0.26% | 0.13% | 0.26% | +| id-template-mGPT-remove-punctuation | 0.11% | 0.00% | 0.19% | +| it-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% | +| qu-template-mGPT-remove-punctuation | 0.00% | 0.00% | 0.27% | +| sw-template-mGPT-remove-punctuation | 0.14% | 0.13% | 0.13% | +| th-template-mGPT-remove-punctuation | 0.25% | 0.13% | 0.13% | +| tr-template-mGPT-remove-punctuation | 0.26% | 0.00% | 0.34% | +| vi-template-mGPT-remove-punctuation | 0.00% | 0.11% | 0.00% | +| zh-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% | + +### LM Evaluation Harness + +Score type: multiple choice grade. + +| lm-eval | 7b | 13b | 70b | +| -- | -- | -- | -- | +| boolq | 0.04% | 0.04% | 0.07% | +| hellaswag | 0.02% | 0.03% | 0.03% | +| piqa | 0.00% | 0.00% | 0.07% | +| winogrande | 0.00% | 0.11% | 0.20% | + +### MMLU + +Score type: multiple choice grade. + +Note: the number in brackets is the number of sub-tasks for each supercategory. + +| mmlu | 7b | 13b | 70b | +| -- | -- | -- | -- | +| stem [18] | 0.79% | 0.05% | 0.01% | +| humanities [13] | 0.19% | 0.01% | 0.02% | +| other (business, health, misc.) [14] | 0.08% | 0.06% | 0.12% | +| social sciences [12] | 0.37% | 0.21% | 0.01% | + +# Llama-3.x + +Llama-3.x checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of several steps: + +1. Get access to download the checkpoints (weights and tokenizer). +2. Convert the checkpoints from Huggingface format to Megatron format. +3. (Optional) Validate converted checkpoints +4. Setup arguments for launching the model. + +The following sections detail these steps. + +## Download Huggingface checkpoints + +Users must first apply for access to download the Llama-3.x checkpoints from [Huggingface](https://huggingface.co/meta-llama). + +## Convert checkpoint format + +We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16. 
+ +### Huggingface format + +The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3.x checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values: + +| Model size | Tensor parallel size (`TP`) | +| ---------- | --------------------------- | +| 1B | 1 | +| 3B | 1 | +| 8B | 1 | +| 70B | 8 | + +Using these values for `TP`, along with the path to the Llama-3.x tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format: + +``` +$>: python tools/checkpoint/convert.py \ + > --bf16 \ + > --model-type GPT \ + > --loader llama_mistral \ + > --saver core \ + > --target-tensor-parallel-size ${TP} \ + > --checkpoint-type hf \ + > --load-dir ${HF_FORMAT_DIR} \ + > --save-dir ${MEGATRON_FORMAT_DIR} \ + > --tokenizer-model ${TOKENIZER_MODEL} \ + > --model-size llama3 \ +``` + +After this conversion, we are ready to load the checkpoints into a Megatron GPT model. + +## (Optional) Validate checkpoints + +A Megatron-LM text generation server for Llama3 can be launched using the script `examples/inference/llama_mistral/run_text_generation_llama3.sh `. For Llama3.1, please use `examples/inference/llama_mistral/run_text_generation_llama3.1.sh`. + +Once running, query the server with `curl 'http://:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":[""], "tokens_to_generate":100, "top_k":1}'`. + +A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path --prompt `. 
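+
+For example, a request to a running server might look like the following. This is only an illustrative sketch: the host name (`localhost`) and the prompt text are placeholders chosen for this example, not values prescribed by the scripts above.
+
+```
+curl 'http://localhost:5000/api' \
+  -X 'PUT' \
+  -H 'Content-Type: application/json; charset=UTF-8' \
+  -d '{"prompts": ["What is the capital of France?"], "tokens_to_generate": 100, "top_k": 1}'
+```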
+ +## Launch model + +If loading for either inference or finetuning, use the following arguments for Llama 3.0: + +``` +--tensor-model-parallel-size ${TP} \ +--pipeline-model-parallel-size 1 \ +--seq-length 8192 \ +--max-position-embeddings 8192 \ +--tokenizer-type HuggingFaceTokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--load ${CHECKPOINT_DIR} \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--no-load-optim \ +--no-load-rng \ +--untie-embeddings-and-output-weights \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--no-masked-softmax-fusion \ +--attention-softmax-in-fp32 \ +--disable-bias-linear \ +--transformer-impl transformer_engine \ +--group-query-attention 8 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--rotary-base 500000 \ +--rotary-percent 1.0 \ +--ffn-hidden-size 14336 \ +--num-attention-heads 32 \ +--swiglu \ +--bf16 \ +``` + +For Llama3.1 please use the following arguments: + +``` +--tensor-model-parallel-size ${TP} \ +--pipeline-model-parallel-size 1 \ +--seq-length 8192 \ +--max-position-embeddings 131072 \ +--tokenizer-type HuggingFaceTokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--load ${CHECKPOINT_DIR} \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--no-load-optim \ +--no-load-rng \ +--untie-embeddings-and-output-weights \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--no-masked-softmax-fusion \ +--attention-softmax-in-fp32 \ +--disable-bias-linear \ +--transformer-impl transformer_engine \ +--group-query-attention 8 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--rotary-base 500000 \ +--rotary-percent 1.0 \ +--use-rope-scaling \ +--ffn-hidden-size 14336 \ +--num-attention-heads 32 \ +--swiglu \ +--bf16 \ +``` + +**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format). + +# Mistral-7b + +Megatron currently supports loading the v0.3 release of Mistral-7b (which does not use sliding window attention and offers a larger 32768 vocabulary) for inference and finetuning. Loading these checkpoints consists of several steps: + +1. Get access to download the checkpoints (weights and tokenizer). +2. Convert the checkpoints from HuggingFace format to Megatron format. +3. (Optional) Validate converted checkpoints +4. Setup arguments for launching the model. + +The following sections detail these steps. + +## Download Huggingface checkpoints + +Users must first apply for access to download the Mistral-7b checkpoints through [Huggingface](https://huggingface.co/mistralai/Mistral-7B-v0.3) (HF). + +## Convert checkpoint format + +The HF checkpoints can be converted to Megatron format by using Megatron's own Mistral checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). + +Using the path to the Mistral tokenizer model (downloaded alongside the HF checkpoint), run the following command from the root of your Megatron source code to convert from HF format to the Megatron core format: + +``` +$>: python tools/checkpoint/convert.py \ + > --bf16 \ + > --model-type GPT \ + > --loader llama_mistral \ + > --saver core \ + > --target-tensor-parallel-size ${TP} \ + > --checkpoint-type hf \ + > --load-dir ${HF_FORMAT_DIR} \ + > --save-dir ${MEGATRON_FORMAT_DIR} \ + > --tokenizer-model ${TOKENIZER_MODEL} \ + > --model-size mistral \ +``` + +After this conversion, we are ready to load the checkpoints into a Megatron core GPT model. 
+
+## (Optional) Validate checkpoints
+
+A Megatron-LM text generation server for Mistral-7B can be launched using the script `examples/inference/llama_mistral/run_text_generation_mistral.sh `.
+
+Once running, query the server with `curl 'http://:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":[""], "tokens_to_generate":100, "top_k":1}'`.
+
+A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/inference/llama_mistral/huggingface_reference.py --model_path --prompt `.
+
+## Launch model
+
+If loading for either inference or finetuning, use the following arguments:
+
+```
+--tensor-model-parallel-size ${TP} \
+--pipeline-model-parallel-size 1 \
+--seq-length 4096 \
+--max-position-embeddings 4096 \
+--tokenizer-type HuggingFaceTokenizer \
+--tokenizer-model ${TOKENIZER_MODEL} \
+--load ${CHECKPOINT_DIR} \
+--exit-on-missing-checkpoint \
+--use-checkpoint-args \
+--no-load-optim \
+--no-load-rng \
+--untie-embeddings-and-output-weights \
+--normalization RMSNorm \
+--position-embedding-type rope \
+--no-masked-softmax-fusion \
+--attention-softmax-in-fp32 \
+--apply-layernorm-1p \
+--transformer-impl transformer_engine \
+--group-query-attention 8 \
+--disable-bias-linear \
+--rotary-base 1000000 \
+--rotary-percent 1.0 \
+--swiglu \
+--ffn-hidden-size 14336 \
+--num-attention-heads 32
+```
+
+**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
+
+# Other Llama-like model support
+
+*Note: Experimental*
+
+Many models such as Yi-34B and Qwen2.x use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama-3.x](#llama-3x).
+
+# Known numerical differences
+
+It is not expected that the Megatron and Huggingface implementations of llama3.x and mistral models will produce numerically identical results. There are multiple points where small numerical differences are expected. This is a non-exhaustive list:
+
+1. TransformerEngine (TE) uses the model params_dtype inside RMSNorm whereas the Huggingface implementation uses fp32. See for details: https://github.com/NVIDIA/TransformerEngine/issues/1132
+2. Huggingface `transformers` implements the q, k and v projections in self-attention as separate GEMMs whereas Megatron core combines them into a single GEMM for efficiency. This leads to small numerical differences.
+
+# Using legacy model format
+
+In all the checkpoint conversion examples used in this document, the saver format `--saver core` is used, signifying that the newer (and recommended) Megatron GPT model class will be used, i.e.:
+
+- old class: `megatron.legacy.model.gpt_model.GPTModel`
+- new class: `megatron.core.models.gpt.gpt_model.GPTModel`
+
+Using this new format is the recommended approach.
However, if your use case requires using the older class (i.e., convert using `--saver legacy`), then when launching training or finetuning, the following args must be added: + +- `--use-legacy-models`: use the older model class +- `--ckpt-format torch`: use the `torch` checkpoint format, which is the only checkpoint format that is compatible with the legacy model format diff --git a/docs/source/api-guide/custom_fsdp.md b/docs/source/api-guide/custom_fsdp.md new file mode 100644 index 0000000000000000000000000000000000000000..5784ae020aea7aaa292104b7eb5b5eb232323a91 --- /dev/null +++ b/docs/source/api-guide/custom_fsdp.md @@ -0,0 +1,183 @@ +# MCore Custom Fully Sharded Data Parallel (FSDP) + +## How to use ? + +Add these flag to enable MCore custom FSDP. + +```bash +--use-custom-fsdp +--data-parallel-sharding-strategy optim_grads_params +--no-gradient-accumulation-fusion +--use-distributed-optimizer +``` + +## Key Features + +- **Sharding Strategy**: Efficiently shards optimizer states, gradients, and parameters to reduce memory consumption. +- **Communication and Computation Overlap**: Optimized to enable concurrent execution of communication and computation, enhancing overall efficiency. +- **Supports automatic mixed precision training**: Compatible with BF16 O1/O2/O3 recipes, as well as FP8 compute with FP32 parameters and FP8 parameter training, allowing for flexible precision configurations. +- **Tensor Parallelism (TP), Expert Parallelism (EP) and Context Parallelism (CP)**: Compatible with TP, EP and CP configurations, enabling efficient scaling of large language models. +- **Distributed Model Initialization with Meta Device**: Allows model initialization using meta device, followed by layer-by-layer initialization of distributed model weight buffers via the `Module.reset_parameters` API, facilitating the initialization of extremely large models. + +## Configuration Recommendations + +### 1. Disable `CUDA_MAX_CONNECTIONS` + +To ensure full parallelization of FSDP communication and computation, disable the CUDA_MAX_CONNECTIONS environment variable. This step avoids potential bubble in CUDA stream. (But it may slow down TP and CP to some extent.) + +```bash +unset CUDA_MAX_CONNECTIONS +``` + +### 2. Add `--calculate-per-token-loss` + +For gradients sharding mode optimization, include the `--calculate-per-token-loss` flag in your training script. This improves performance by reducing the frequency of gradient scaling, which is also a sizable drain on SM resources. + +## Design of Custom FSDP + +### 1. Overview + +The custom Fully Sharded Data Parallelism (FSDP) implementation in Megatron-Core is specifically designed to optimize memory consumption and performance for large language models. The core design principles include: + + - **Optimized for Large Language Models**: This custom FSDP implementation is tailored to efficiently scale with models containing billions of parameters, ensuring seamless execution and training of massive models. + - **Efficient Memory Consumption**: By strategically sharding optimizer states, gradients, and model parameters, the custom FSDP significantly reduces memory usage. This approach enables the training of models that would otherwise be too large to fit in memory. + - **Efficient Workflow & Overlapping Communication and Computation**: The implementation is engineered to minimize the number of communication steps required during training. 
It maximizes the overlap between communication and computation, thereby enhancing overall training efficiency and reducing latency. + - **Support for MCore's Efficient Training Methods**: The custom FSDP seamlessly integrates with Megatron-Core's advanced parallelism techniques, including tensor parallelism, expert parallelism and context parallelism. Additionally, it supports automatic mixed precision training, further optimizing training performance and efficiency. + +The design of Custom FSDP draws inspiration from PyTorch FSDP [Zhao, Yanli, et al.](https://arxiv.org/pdf/2304.11277) and MCore's distributed optimizer. The introduction to PyTorch FSDP is referenced here to clarify the underlying concepts of the custom FSDP design. + +> In DistributedDataParallel, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks. + +> When training with FSDP, the GPU memory footprint is smaller than when training with DDP across all workers. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. This comes with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like overlapping communication and computation. + +![FSDP workflow](../images/custom_fsdp/FSDP_workflow.png) + +*Notice that the unit processed in workflow here is the “FSDP instance 1: N layers”, where an FSDP instance is the smallest FSDP processing unit (also a PyTorch module), which means that we can safely release this module weights after using it (executing the forward or backward of this module), and there will be no other computations computations relying on these weights. This capability is the foundation of FSDP's layer-by-layer execution and memory-saving strategy. An FSDP instance is also referred to as an **FSDP Unit**.* + +*It is worth noting that an FSDP instance can correspond to multiple FSDP parameter groups. These groups are separated by Data Parallel (DP) communication groups and the data type of the parameter or gradient. Consequently, an FSDP instance may require several parameter-gather tasks before execution (forward or backward). Each **FSDP parameter group** corresponds to one **Data Parallel Buffer** in custom FSDP.* + +At a high level FSDP works as follow: + +In constructor + - Shard model parameters and each rank only keeps its own shard + +In forward path + - Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit + - Run forward computation + - Discard parameter shards it has just collected + +In backward path + - Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit + - Run backward computation + - Run reduce_scatter to sync gradients + - Discard parameters. + +One way to view FSDP’s sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. 
Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards. + +![FSDP Allreduce](../images/custom_fsdp/FSDP_Allreduce.png) + +### 2. Custom FSDP underlying data structure + +To implement the FSDP functionality described above, the custom FSDP is designed with the following Python classes and data structure: + +![MCore Custom FSDP Class Diagram](../images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png) + +### 3. The custom FSDP interface: FullyShardedDataParallel + +The custom FSDP provides the same programming interface as PyTorch's DistributedDataParallel (DDP) as FullyShardedDataParallel (FSDP). For example, you can apply FSDP to models as follows: + +```python +# Initialize model and optimizer +ddp_config.use_custom_fsdp = True +ddp_config.data_parallel_sharding_strategy = "optim_grads_params" +model = GPTModel(transformer_config) +model = FullyShardedDataParallel( + transformer_config, + model, + ddp_config, + fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding], +) +optimizer = torch.optim.AdamW(model.parameters(), lr=lr) +optimizer = DistributedOptimizer(optimizer, [model], [model.param_and_grad_buffer]) + +# Training loop +def train_step(inputs, labels): + optimizer.zero_grad() + for mbs_input, mbs_label in zip(inputs, labels): + outputs = model(mbs_input) + loss = loss_fn(outputs, mbs_label) + loss.backward() + optimizer.step() + +# Save and load model and optimizer state dict +def model_and_optimizer_state_dict(): + state_dict = { + "model": model.sharded_state_dict(), + "optimizer": optimizer.sharded_state_dict(), + } + return state_dict + +def load_model_and_optimizer_state_dict(state_dict): + model.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optimizer"]) +``` + +**Key Notes:** + - You can configure which modules should be treated as FSDP units via the `fsdp_unit_modules` argument. This configuration is mandatory. + - The custom FSDP must be used with a distributed optimizer since it provides distributed checkpointing. + - The data-parallel communication group for parameters is not explicitly shown. Custom FSDP configures these groups as either DP (data-parallel) or EDP (expert data-parallel) based on parameter markings. + +#### 3.1 Initializing Models on the Meta Device + +For training particularly large models with FSDP, you can initialize the model on the meta device. Using PyTorch's `reset_parameters` API, you can initialize model weights layer by layer during the construction of the `ParamAndGradBuffer`. Most PyTorch native modules and TransformerEngine modules support this API (e.g., [PyTorch Linear](https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/modules/linear.py#L114), [TE LayerNormLinear](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.0/transformer_engine/pytorch/module/layernorm_linear.py#L1107)). + +```python +# Initialize model on meta device +with torch.device("meta"): + model = GPTModel(config) + +model = FullyShardedDataParallel( + transformer_config, + model, + ddp_config, + fsdp_unit_modules=[TransformerLayer, LanguageModelEmbedding], +) +``` + +**Important Considerations:** +1. *Custom Modules*: If your model contains custom modules, ensure they implement the `reset_parameters` API. Otherwise, you may need to force parameter initialization on a CUDA or CPU device. +2. *Tensor Initialization*: Be cautious of tensors created during model initialization without a specified device—they will default to the meta device. 
To avoid issues, explicitly specify the device for these tensors to ensure compatibility with this function. + +### 4. Interaction between Custom FSDP and Model Forward/Backward Propagation + +Custom FSDP implements Fully Sharded Data Parallelism (FSDP) through a series of module hooks, gradient hooks, or by adding functions between modules. This involves inserting communications and manipulating parameters and gradients during PyTorch's module forward or backward propagation. + +Module hooks summary: +- Module pre-forward hook(`module.register_forward_pre_hook`): This hook unshards model weights before the forward pass. In the case of an FSDP Unit Module, add a RegisterFSDPBackwardFunction function that will release the module's modes on backward propagation. +- Module post-forward hook(`module.register_forward_hook`): This hook is used to reshard model weights after the forward pass. +- Root module pre-backward hook(`root_module.register_full_backward_pre_hook`): This hook checks that all model parameters are resharded, in order to avoid unnecessary memory spikes. It also marks all modules as being in the `TrainingState.PRE_BACKWARD` state. +- Module pre-backward hook(`module.register_full_backward_pre_hook`): This hook is used to unshard the model weights before the backward pass. +- Gradient accumulation hook(`grad_acc.register_hook`): This hook is used to accumulate gradients and trigger the gradient reduction pipeline. + + +The gradient reduction pipeline maintains a map of gradients to FSDP parameter groups. If all gradients in an FSDP parameter group are ready, it launches a gradient reduction. Note that this assumes that the model's gradients are always generated in a certain order (reverse of `module.parameters()`), as otherwise, FSDP would maintain too many parameter group grad buffers, leading to excessive memory usage. + +#### 4.1 Optimized for Activation Recompute + +Using the activation recompute will cause the same module to execute the forward function first and then the backward function in the backward prop, which will cause model weights unshard twice and model weights reshard twice. If we can tell program that this is a forward + backward operation, we can just call unshard once and reshard once. + +To make this determination, we keep track of the model's state with training_state, `FORWARD`, `PRE_BACKWARD`, `POST_BACKWARD`, `IDLE`. It's worth noting that pre-backward hook act before pre-forward hook, and we'll let pre-backward hook execute the model weight unshard, and then mark the model as `PRE_BACKWARD`, and when pre-forward hook sees this marking it will not perform the unshard operation. Similarly, for model weight reshard duplicate, post-forward hook act before post-backward function, and checking for the `PRE_BACKWARD` flag in the post-forward hook will cancel the unshard. + +### 5. Memory Mechanisms and Features of Custom FSDP + +FSDP can fully distribute the model parameters, gradients, and optimizer states, and for mixed-precision training, it can also fully distribute the high-precision main weights. This is pretty much distributes all the memory except for the activation memory, but FSDP will also face some memory issues. + +FSDP frequently unshards and reshards model weights, which can lead to busy memory allocation and deallocation. This results in untimely tensor releases, causing memory spikes (or even out-of-memory errors), crashes of the PyTorch memory allocator cache, and a large number of `cudaMalloc` and `cudaFree` calls. 
+
+#### 4.1 Optimization for Activation Recomputation
+
+With activation recomputation, the same module runs its forward function and then its backward function during the backward pass, which would normally unshard the model weights twice and reshard them twice. If we can tell the program that this is a combined forward + backward operation, it only needs to unshard once and reshard once.
+
+To make this determination, we track the model's state with `training_state`, which is one of `FORWARD`, `PRE_BACKWARD`, `POST_BACKWARD`, or `IDLE`. Note that the pre-backward hook acts before the pre-forward hook, so we let the pre-backward hook perform the weight unshard and then mark the module as `PRE_BACKWARD`; when the pre-forward hook sees this mark, it skips the unshard. Similarly, to avoid the duplicate reshard, the post-forward hook acts before the post-backward function, and checking for the `PRE_BACKWARD` flag in the post-forward hook cancels the reshard.
+
+### 5. Memory Mechanisms and Features of Custom FSDP
+
+FSDP can fully distribute the model parameters, gradients, and optimizer states, and, for mixed-precision training, it can also fully distribute the high-precision main weights. This distributes essentially all of the memory except for the activation memory, but FSDP still runs into some memory issues of its own.
+
+FSDP frequently unshards and reshards model weights, which can lead to busy memory allocation and deallocation. This results in untimely tensor releases, causing memory spikes (or even out-of-memory errors), crashes of the PyTorch memory allocator cache, and a large number of `cudaMalloc` and `cudaFree` calls. These issues can significantly slow down the system.
+
+The problem of untimely tensor release can generally be addressed with the `tensor._typed_storage()._resize_(0)` API, which immediately deallocates the storage's memory. Custom FSDP provides interfaces in `AllGatherPipeline` and `GradReducePipeline` to replace the temporary buffer memory allocator used for parameter gathering and gradient reduction with `StorageResizeBasedBucketAllocator`, which swaps the tensor release operation for the `tensor._typed_storage()._resize_(0)` API.
+
+The PyTorch memory allocator cache crash is a complex issue that occurs frequently when actual memory usage approaches the GPU memory limit, and it leads to poor performance. This problem is hard to solve and can only be mitigated by avoiding frequent hits on the GPU memory limit. Using a self-managed memory allocator such as `RotaryBucketAllocator` is another potential solution; however, note that `RotaryBucketAllocator` is not yet mature.
+
+## References
+
+- [Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
diff --git a/docs/source/api-guide/index.rst b/docs/source/api-guide/index.rst
index dac785af04a303f8ee44179a50a18c14b4cb556d..5212fdcf4bba886d7a438cf509e097f6eb8cf374 100644
--- a/docs/source/api-guide/index.rst
+++ b/docs/source/api-guide/index.rst
@@ -1,20 +1,23 @@
-API Guide
-=========
-
-.. toctree::
-   :maxdepth: 4
-
-   models
-   tensor_parallel
-   context_parallel
-   pipeline_parallel
-   fusions
-   transformer
-   moe
-   dist_checkpointing
-   dist_optimizer
-   distributed
-   datasets
-   num_microbatches_calculator
-   optimizer_param_scheduler
+API Guide
+=========
+
+.. toctree::
+   :maxdepth: 4
+
+   models
+   tensor_parallel
+   context_parallel
+   pipeline_parallel
+   custom_fsdp
+   fusions
+   transformer
+   moe
+   dist_checkpointing
+   dist_optimizer
+   distributed
+   datasets
+   multi_latent_attention
+   num_microbatches_calculator
+   optimizer_param_scheduler
+   optimizer_cpu_offload
 encoder_decoder_parallelism
\ No newline at end of file
diff --git a/docs/source/api-guide/multi_latent_attention.rst b/docs/source/api-guide/multi_latent_attention.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c4d043a4e104959ee018d38e600e7e8bb0f6e6b5
--- /dev/null
+++ b/docs/source/api-guide/multi_latent_attention.rst
@@ -0,0 +1,28 @@
+Multi-Latent Attention
+======================
+
+Multi-Latent Attention overview
+-------------------------------
+
+Multi-Latent Attention ("MLA") is an innovative attention mechanism, introduced by the Deepseek team, that improves the efficiency of attention computation by leveraging multiple latent spaces. This approach is particularly beneficial for large language models (LLMs), as it reduces the computational burden associated with traditional attention mechanisms. According to the Deepseek-V2 technical report, MLA achieves better performance than Multi-Head Attention (MHA) while requiring a smaller KV cache.
+
+Enabling Multi-Latent Attention
+-------------------------------
+
+To enable MLA in Megatron-LM:
+- Pass `--multi-latent-attention` on the command line to enable MLA.
+- Use `MLATransformerConfig` to configure MLA; a configuration sketch follows.
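+
+For illustration, a minimal sketch of the configuration step is shown below. The import path and the assumption that `MLATransformerConfig` extends `TransformerConfig` are made for this sketch only; consult the Megatron Core source for the authoritative definition.
+
+.. code-block:: python
+
+    # Assumed location of MLATransformerConfig, alongside TransformerConfig.
+    from megatron.core.transformer.transformer_config import MLATransformerConfig
+
+    # Only base TransformerConfig fields are shown; the values are illustrative.
+    mla_config = MLATransformerConfig(
+        num_layers=24,
+        hidden_size=1024,
+        num_attention_heads=16,
+    )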
diff --git a/docs/source/api-guide/optimizer_cpu_offload.rst b/docs/source/api-guide/optimizer_cpu_offload.rst new file mode 100644 index 0000000000000000000000000000000000000000..7d8099449584c29c56b81b8ea8f1894824531702 --- /dev/null +++ b/docs/source/api-guide/optimizer_cpu_offload.rst @@ -0,0 +1,4 @@ +Optimizer CPU offload package +============================== + +.. mdinclude :: ../../../megatron/core/optimizer/cpu_offloading/README.md diff --git a/docs/source/images/custom_fsdp/FSDP_Allreduce.png b/docs/source/images/custom_fsdp/FSDP_Allreduce.png new file mode 100644 index 0000000000000000000000000000000000000000..66e2391ed0415b5919f53b3510d25d53d2a6b728 Binary files /dev/null and b/docs/source/images/custom_fsdp/FSDP_Allreduce.png differ diff --git a/docs/source/images/custom_fsdp/FSDP_workflow.png b/docs/source/images/custom_fsdp/FSDP_workflow.png new file mode 100644 index 0000000000000000000000000000000000000000..588b6f220a375163b34a9ae17254b42678e445b8 Binary files /dev/null and b/docs/source/images/custom_fsdp/FSDP_workflow.png differ diff --git a/docs/source/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png b/docs/source/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..f9603079b928fb0c05ad077e654a72725844faf3 Binary files /dev/null and b/docs/source/images/custom_fsdp/MCore_Custom_FSDP_Class_Diagram.png differ diff --git a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh old mode 100644 new mode 100755 index a212fbdf3f6cef5a88a2faab8e229158fdf883b4..3f784ebe8a4536cc91544f5922d3ec0c33400585 --- a/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh +++ b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh @@ -1,63 +1,63 @@ -#! /bin/bash - -# Change for multinode config -GPUS_PER_NODE=16 -MASTER_ADDR=localhost -MASTER_PORT=$(($RANDOM + 1024)) -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -# input -DATA_PATH=$1 -SHARE_DATA=$PWD # current work dir -FINETUNED_PATH="$SHARE_DATA/$2" -lr=$3 -bs=$4 -iter=$5 -CHECKPOINT_PATH=$6 - -# vocab -VOCAB_FILE=gpt2-vocab.json # Your gpt-2 vocab -MERGE_FILE=gpt2-merges.txt # Your gpt-2 merge file - -# tensorboard -TENSORBOARD_DIR="$SHARE_DATA/tensorboard/$2" -mkdir -p ${TENSORBOARD_DIR} - -DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" - -python -m torch.distributed.run $DISTRIBUTED_ARGS \ - examples/detxoify_lm/finetune_gpt.py \ - --num-layers 24 \ - --hidden-size 2048 \ - --num-attention-heads 32 \ - --micro-batch-size 4 \ - --global-batch-size $bs \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters $iter \ - --save $FINETUNED_PATH \ - --load $CHECKPOINT_PATH \ - --data-path $DATA_PATH \ - --data-path2 ${DATA_BLEND} \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 100,0,0 \ - --distributed-backend nccl \ - --lr-decay-style constant \ - --lr $lr \ - --clip-grad 1.0 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --checkpoint-activations \ - --log-interval 1 \ - --save-interval 78 \ - --eval-interval 78 \ - --eval-iters 50 \ - --fp16 \ - --DDP-impl local \ - --finetune --no-load-optim \ - --log-validation-ppl-to-tensorboard \ - --tensorboard-dir ${TENSORBOARD_DIR} +#! 
/bin/bash + +# Change for multinode config +GPUS_PER_NODE=16 +MASTER_ADDR=localhost +MASTER_PORT=$(($RANDOM + 1024)) +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# input +DATA_PATH=$1 +SHARE_DATA=$PWD # current work dir +FINETUNED_PATH="$SHARE_DATA/$2" +lr=$3 +bs=$4 +iter=$5 +CHECKPOINT_PATH=$6 + +# vocab +VOCAB_FILE=gpt2-vocab.json # Your gpt-2 vocab +MERGE_FILE=gpt2-merges.txt # Your gpt-2 merge file + +# tensorboard +TENSORBOARD_DIR="$SHARE_DATA/tensorboard/$2" +mkdir -p ${TENSORBOARD_DIR} + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.run $DISTRIBUTED_ARGS \ + examples/detxoify_lm/finetune_gpt.py \ + --num-layers 24 \ + --hidden-size 2048 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size $bs \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters $iter \ + --save $FINETUNED_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --data-path2 ${DATA_BLEND} \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 100,0,0 \ + --distributed-backend nccl \ + --lr-decay-style constant \ + --lr $lr \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --checkpoint-activations \ + --log-interval 1 \ + --save-interval 78 \ + --eval-interval 78 \ + --eval-iters 50 \ + --fp16 \ + --DDP-impl local \ + --finetune --no-load-optim \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-dir ${TENSORBOARD_DIR} diff --git a/examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh b/examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/data_processing.sh b/examples/academic_paper_scripts/msdp/data_processing.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/eval_knwl_generation.sh b/examples/academic_paper_scripts/msdp/eval_knwl_generation.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/eval_resp_generation.sh b/examples/academic_paper_scripts/msdp/eval_resp_generation.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/prep_resp_gen.sh b/examples/academic_paper_scripts/msdp/prep_resp_gen.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh b/examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/msdp/prompt_resp_gen.sh b/examples/academic_paper_scripts/msdp/prompt_resp_gen.sh old mode 100644 new mode 100755 diff --git a/examples/academic_paper_scripts/sc21/CONFIG.sh b/examples/academic_paper_scripts/sc21/CONFIG.sh old mode 100644 new mode 100755 index f17ccd7b023ca9aeb538ba38a60808e44418873b..180c1fee82124c0b6906aeb4c61d04a24bd0f6fa --- a/examples/academic_paper_scripts/sc21/CONFIG.sh +++ b/examples/academic_paper_scripts/sc21/CONFIG.sh @@ -1,57 +1,57 @@ -#!/bin/bash - - -# SLURM options. -export SLURM_PARTITION= -export SLURM_ACCOUNT= - - -# Source code. -export MEGATRON_CODE_DIR= - - -# This variable is used to mount the relevant part of the filesystem -# inside the docker container. 
Note that the `MEGATRON_CODE_DIR` and the -# launch directory already get mounted; this variable should be used to -# mount the directories that contain the data and tokenizer files. -export DOCKER_MOUNT_DIR= - - -# Data and tokenizer files. -MEGATRON_DATA= -BPE_VOCAB_FILE= -BPE_MERGE_FILE= - - -# Megatron input parameters. -# `MEGATRON_EXTRA_PARAMS` can be used to provide any extra parameters -# that are not listed here. -export MEGATRON_PARAMS=" ${MEGATRON_EXTRA_PARAMS} \ - --tensor-model-parallel-size ${TP} \ - --pipeline-model-parallel-size ${PP} \ - --micro-batch-size ${MBS} \ - --global-batch-size ${GBS} \ - --num-layers ${NLS} \ - --hidden-size ${HS} \ - --num-attention-heads ${NAH} \ - --DDP-impl ${DDP} \ - --data-path ${MEGATRON_DATA} \ - --vocab-file ${BPE_VOCAB_FILE} \ - --merge-file ${BPE_MERGE_FILE} \ - --log-interval 5 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters 500 \ - --lr-decay-iters 320 \ - --lr 0.0001 \ - --min-lr 0.00001 \ - --lr-decay-style cosine \ - --lr-warmup-fraction 0.01 \ - --split 969,30,1 \ - --eval-iters 100 \ - --eval-interval 1000 \ - --clip-grad 1.0 \ - --fp16 \ - --loss-scale 8192 " - - +#!/bin/bash + + +# SLURM options. +export SLURM_PARTITION= +export SLURM_ACCOUNT= + + +# Source code. +export MEGATRON_CODE_DIR= + + +# This variable is used to mount the relevant part of the filesystem +# inside the docker container. Note that the `MEGATRON_CODE_DIR` and the +# launch directory already get mounted; this variable should be used to +# mount the directories that contain the data and tokenizer files. +export DOCKER_MOUNT_DIR= + + +# Data and tokenizer files. +MEGATRON_DATA= +BPE_VOCAB_FILE= +BPE_MERGE_FILE= + + +# Megatron input parameters. +# `MEGATRON_EXTRA_PARAMS` can be used to provide any extra parameters +# that are not listed here. 
+export MEGATRON_PARAMS=" ${MEGATRON_EXTRA_PARAMS} \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --num-layers ${NLS} \ + --hidden-size ${HS} \ + --num-attention-heads ${NAH} \ + --DDP-impl ${DDP} \ + --data-path ${MEGATRON_DATA} \ + --vocab-file ${BPE_VOCAB_FILE} \ + --merge-file ${BPE_MERGE_FILE} \ + --log-interval 5 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500 \ + --lr-decay-iters 320 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --lr-decay-style cosine \ + --lr-warmup-fraction 0.01 \ + --split 969,30,1 \ + --eval-iters 100 \ + --eval-interval 1000 \ + --clip-grad 1.0 \ + --fp16 \ + --loss-scale 8192 " + + diff --git a/examples/academic_paper_scripts/sc21/SBATCH.sh b/examples/academic_paper_scripts/sc21/SBATCH.sh old mode 100644 new mode 100755 index 95431b9b7e780bbdd4b18593546356aad02945b1..4516a249db16d9832e76ccce4245fa9e1245f2c0 --- a/examples/academic_paper_scripts/sc21/SBATCH.sh +++ b/examples/academic_paper_scripts/sc21/SBATCH.sh @@ -1,13 +1,13 @@ -#!/bin/bash - - -sbatch -p ${SLURM_PARTITION} \ - -A ${SLURM_ACCOUNT} \ - --job-name=${JOB_NAME} \ - --nodes=${NNODES} \ - --export=MEGATRON_CODE_DIR,MEGATRON_PARAMS,DOCKER_MOUNT_DIR SRUN.sh - -exit 0 - - - +#!/bin/bash + + +sbatch -p ${SLURM_PARTITION} \ + -A ${SLURM_ACCOUNT} \ + --job-name=${JOB_NAME} \ + --nodes=${NNODES} \ + --export=MEGATRON_CODE_DIR,MEGATRON_PARAMS,DOCKER_MOUNT_DIR SRUN.sh + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/SRUN.sh b/examples/academic_paper_scripts/sc21/SRUN.sh old mode 100644 new mode 100755 index 52a9aff0c1294acb1e5527faad4f73fe5e027e21..717ff53b89ee92fc3a7e0742c720376e30cad1ac --- a/examples/academic_paper_scripts/sc21/SRUN.sh +++ b/examples/academic_paper_scripts/sc21/SRUN.sh @@ -1,18 +1,18 @@ -#!/bin/bash - -#SBATCH -t 0:30:00 --exclusive --mem=0 --overcommit --ntasks-per-node=8 - - -THIS_DIR=`pwd` -DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -mkdir -p ${THIS_DIR}/logs - - -CMD="python -u ${MEGATRON_CODE_DIR}/pretrain_gpt.py ${MEGATRON_PARAMS}" - - -srun -l \ - --container-image "nvcr.io#nvidia/pytorch:20.12-py3" \ - --container-mounts "${THIS_DIR}:${THIS_DIR},${MEGATRON_CODE_DIR}:${MEGATRON_CODE_DIR},${DOCKER_MOUNT_DIR}:${DOCKER_MOUNT_DIR}" \ - --output=${THIS_DIR}/logs/%x_%j_$DATETIME.log sh -c "${CMD}" - +#!/bin/bash + +#SBATCH -t 0:30:00 --exclusive --mem=0 --overcommit --ntasks-per-node=8 + + +THIS_DIR=`pwd` +DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` +mkdir -p ${THIS_DIR}/logs + + +CMD="python -u ${MEGATRON_CODE_DIR}/pretrain_gpt.py ${MEGATRON_PARAMS}" + + +srun -l \ + --container-image "nvcr.io#nvidia/pytorch:20.12-py3" \ + --container-mounts "${THIS_DIR}:${THIS_DIR},${MEGATRON_CODE_DIR}:${MEGATRON_CODE_DIR},${DOCKER_MOUNT_DIR}:${DOCKER_MOUNT_DIR}" \ + --output=${THIS_DIR}/logs/%x_%j_$DATETIME.log sh -c "${CMD}" + diff --git a/examples/academic_paper_scripts/sc21/run_figure_11.sh b/examples/academic_paper_scripts/sc21/run_figure_11.sh old mode 100644 new mode 100755 index 2ec7d9eb31e50e01e3d5dab6978a71deffd247aa..ff0594afc8c560c10171684a0e845367d5f48b1f --- a/examples/academic_paper_scripts/sc21/run_figure_11.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_11.sh @@ -1,46 +1,46 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Pipeline-parallel size options = [1, 2, 4, 8]. -PP=1 - -# Batch size (global batch size) options = [8, 128]. 
-GBS=8 - - - - - -# Set pipeline-parallel size options. -NLS=$((3*PP)) -NNODES=${PP} - - -# Other params. -TP=8 -MBS=1 -HS=20480 -NAH=128 -DDP=local -MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " - - -# Name of the job. -export JOB_NAME=results_figure_11_pipeline_parallel_size_${PP}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Pipeline-parallel size options = [1, 2, 4, 8]. +PP=1 + +# Batch size (global batch size) options = [8, 128]. +GBS=8 + + + + + +# Set pipeline-parallel size options. +NLS=$((3*PP)) +NNODES=${PP} + + +# Other params. +TP=8 +MBS=1 +HS=20480 +NAH=128 +DDP=local +MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " + + +# Name of the job. +export JOB_NAME=results_figure_11_pipeline_parallel_size_${PP}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_12.sh b/examples/academic_paper_scripts/sc21/run_figure_12.sh old mode 100644 new mode 100755 index 11e550854de4cd576d9625ca9dd5330d44fffb76..df06eb55e2de77fcef6c22c0567ff52376cab4fd --- a/examples/academic_paper_scripts/sc21/run_figure_12.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_12.sh @@ -1,54 +1,54 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Interleaved schedule options = [YES, NO]. -INTERLEAVED=YES - -# Batch size (global batch size) options = [12, 24, 36, ..., 60]. -GBS=12 - - - - - -# Set interleaved schedule options. -if [ ${INTERLEAVED} == "YES" ]; then - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 " -elif [ ${INTERLEAVED} == "NO" ]; then - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -else - echo "Invalid configuration" - exit 1 -fi - - -# Other params. -TP=8 -PP=12 -MBS=1 -NLS=96 -HS=12288 -NAH=96 -DDP=local -NNODES=12 - - -# Name of the job. -export JOB_NAME=results_figure_12_interleaved_${INTERLEAVED}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Interleaved schedule options = [YES, NO]. +INTERLEAVED=YES + +# Batch size (global batch size) options = [12, 24, 36, ..., 60]. +GBS=12 + + + + + +# Set interleaved schedule options. +if [ ${INTERLEAVED} == "YES" ]; then + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 " +elif [ ${INTERLEAVED} == "NO" ]; then + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +else + echo "Invalid configuration" + exit 1 +fi + + +# Other params. +TP=8 +PP=12 +MBS=1 +NLS=96 +HS=12288 +NAH=96 +DDP=local +NNODES=12 + + +# Name of the job. +export JOB_NAME=results_figure_12_interleaved_${INTERLEAVED}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. 
`pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_13.sh b/examples/academic_paper_scripts/sc21/run_figure_13.sh old mode 100644 new mode 100755 index 7ba560e87b253fb63192866d3089c3d967f086e6..2c75c6081f62088eb84bf275e28ab62af3cb203e --- a/examples/academic_paper_scripts/sc21/run_figure_13.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_13.sh @@ -1,46 +1,46 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Pipeline-parallel size options = [2, 4, 8, 16, 32]. -PP=2 - -# Batch size (global batch size) options = [32, 128]. -GBS=32 - - - - - -# Set pipeline-parallel and tensor-parallel size options. -TP=$((64/PP)) - - -# Other params. -MBS=1 -NLS=32 -HS=20480 -NAH=128 -DDP=local -MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -NNODES=8 - - -# Name of the job. -export JOB_NAME=results_figure_13_pipeline_parallel_size_${PP}_tensor_parallel_size_${TP}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Pipeline-parallel size options = [2, 4, 8, 16, 32]. +PP=2 + +# Batch size (global batch size) options = [32, 128]. +GBS=32 + + + + + +# Set pipeline-parallel and tensor-parallel size options. +TP=$((64/PP)) + + +# Other params. +MBS=1 +NLS=32 +HS=20480 +NAH=128 +DDP=local +MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +NNODES=8 + + +# Name of the job. +export JOB_NAME=results_figure_13_pipeline_parallel_size_${PP}_tensor_parallel_size_${TP}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_14.sh b/examples/academic_paper_scripts/sc21/run_figure_14.sh old mode 100644 new mode 100755 index 4b83879c4bb71546a7fb5bac365491efd96d3049..87ac082d79c8b797aab3513a2c6852aab26c8ef3 --- a/examples/academic_paper_scripts/sc21/run_figure_14.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_14.sh @@ -1,47 +1,47 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Pipeline-parallel size options = [2, 4, 8, 16, 32]. -PP=2 - -# Batch size (global batch size) options = [32, 512]. -GBS=32 - - - - - -# Set pipeline-parallel and data-parallel size options. -DP=$((64/PP)) - - -# Other params. -TP=1 -MBS=1 -NLS=32 -HS=3840 -NAH=32 -DDP=local -MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -NNODES=8 - - -# Name of the job. -export JOB_NAME=results_figure_14_pipeline_parallel_size_${PP}_data_parallel_size_${DP}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Pipeline-parallel size options = [2, 4, 8, 16, 32]. +PP=2 + +# Batch size (global batch size) options = [32, 512]. +GBS=32 + + + + + +# Set pipeline-parallel and data-parallel size options. +DP=$((64/PP)) + + +# Other params. +TP=1 +MBS=1 +NLS=32 +HS=3840 +NAH=32 +DDP=local +MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +NNODES=8 + + +# Name of the job. +export JOB_NAME=results_figure_14_pipeline_parallel_size_${PP}_data_parallel_size_${DP}_batch_size_${GBS} + + +# Import the configs. +. 
`pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_15.sh b/examples/academic_paper_scripts/sc21/run_figure_15.sh old mode 100644 new mode 100755 index 547ad1de6fb091ca5f922e2b48559ceadffa7ce8..f47150f78907be49811647ff171e11736191477e --- a/examples/academic_paper_scripts/sc21/run_figure_15.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_15.sh @@ -1,47 +1,47 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Tensor-parallel size options = [2, 4, 8, 16, 32]. -TP=2 - -# Batch size (global batch size) options = [32, 128, 512]. -GBS=32 - - - - - -# Set tensor-parallel and data-parallel size options. -DP=$((64/TP)) - - -# Other params. -PP=1 -MBS=1 -NLS=32 -HS=3840 -NAH=32 -DDP=local -MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -NNODES=8 - - -# Name of the job. -export JOB_NAME=results_figure_15_tensor_parallel_size_${TP}_data_parallel_size_${DP}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Tensor-parallel size options = [2, 4, 8, 16, 32]. +TP=2 + +# Batch size (global batch size) options = [32, 128, 512]. +GBS=32 + + + + + +# Set tensor-parallel and data-parallel size options. +DP=$((64/TP)) + + +# Other params. +PP=1 +MBS=1 +NLS=32 +HS=3840 +NAH=32 +DDP=local +MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +NNODES=8 + + +# Name of the job. +export JOB_NAME=results_figure_15_tensor_parallel_size_${TP}_data_parallel_size_${DP}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_16.sh b/examples/academic_paper_scripts/sc21/run_figure_16.sh old mode 100644 new mode 100755 index 8c353a3e7623262baf9dc6c24554e9ab4dce26e7..7f612574be575ef6cd883d0b29321b9581a0df18 --- a/examples/academic_paper_scripts/sc21/run_figure_16.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_16.sh @@ -1,43 +1,43 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Microbatch size options = [1, 2, 4, 8]. -MBS=1 - -# Batch size (global batch size) options = [128, 512]. -GBS=128 - - - - - -# Other params. -TP=8 -PP=8 -NLS=32 -HS=15360 -NAH=128 -DDP=local -MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -NNODES=8 - - -# Name of the job. -export JOB_NAME=results_figure_16_microbatch_size_${MBS}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Microbatch size options = [1, 2, 4, 8]. +MBS=1 + +# Batch size (global batch size) options = [128, 512]. +GBS=128 + + + + + +# Other params. +TP=8 +PP=8 +NLS=32 +HS=15360 +NAH=128 +DDP=local +MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +NNODES=8 + + +# Name of the job. +export JOB_NAME=results_figure_16_microbatch_size_${MBS}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. 
`pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_17.sh b/examples/academic_paper_scripts/sc21/run_figure_17.sh old mode 100644 new mode 100755 index d6899b321d6c11238af3b12da3690c8c3d46be34..6da59de4e5c286881a26bcf75ed11460ee8d7faa --- a/examples/academic_paper_scripts/sc21/run_figure_17.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_17.sh @@ -1,54 +1,54 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Activation recomputation options = [YES, NO]. -ACTIVATION_RECOMPUTATION=YES - -# Batch size (global batch size) options = [1, 2, 4, ..., 256]. -GBS=1 - - - - - -# Set activation recomputation. -if [ ${ACTIVATION_RECOMPUTATION} == "YES" ]; then - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${ACTIVATION_RECOMPUTATION} == "NO" ]; then - MEGATRON_EXTRA_PARAMS="" -else - echo "Invalid configuration" - exit 1 -fi - - -# Other params. -TP=8 -PP=16 -MBS=1 -NLS=80 -HS=12288 -NAH=96 -DDP=local -NNODES=16 - - -# Name of the job. -export JOB_NAME=results_figure_17_activation_recomputation_${ACTIVATION_RECOMPUTATION}_batch_size_${GBS} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Activation recomputation options = [YES, NO]. +ACTIVATION_RECOMPUTATION=YES + +# Batch size (global batch size) options = [1, 2, 4, ..., 256]. +GBS=1 + + + + + +# Set activation recomputation. +if [ ${ACTIVATION_RECOMPUTATION} == "YES" ]; then + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${ACTIVATION_RECOMPUTATION} == "NO" ]; then + MEGATRON_EXTRA_PARAMS="" +else + echo "Invalid configuration" + exit 1 +fi + + +# Other params. +TP=8 +PP=16 +MBS=1 +NLS=80 +HS=12288 +NAH=96 +DDP=local +NNODES=16 + + +# Name of the job. +export JOB_NAME=results_figure_17_activation_recomputation_${ACTIVATION_RECOMPUTATION}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_figure_18.sh b/examples/academic_paper_scripts/sc21/run_figure_18.sh old mode 100644 new mode 100755 index 88924fb820be4767ed6aa00633682ece581329db..0ddd8a27eea02575989f175bdf05cc21b59d5234 --- a/examples/academic_paper_scripts/sc21/run_figure_18.sh +++ b/examples/academic_paper_scripts/sc21/run_figure_18.sh @@ -1,54 +1,54 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ - -# Scatter-gather communication optimization options = [YES, NO]. -SCATTER_GATHER=YES - -# Batch size (global batch size) options = [12, 24, 36, ..., 60]. -GBS=12 - - - - - -# Set scatter-gather communication optimization options. -if [ ${SCATTER_GATHER} == "YES" ]; then - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 " -elif [ ${SCATTER_GATHER} == "NO" ]; then - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 --no-scatter-gather-tensors-in-pipeline " -else - echo "Invalid configuration" - exit 1 -fi - - -# Other params. -TP=8 -PP=12 -MBS=1 -NLS=96 -HS=12288 -NAH=96 -DDP=local -NNODES=12 - - -# Name of the job. -export JOB_NAME=results_figure_18_scatter_gather_${SCATTER_GATHER}_batch_size_${GBS} - - -# Import the configs. -. 
`pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ + +# Scatter-gather communication optimization options = [YES, NO]. +SCATTER_GATHER=YES + +# Batch size (global batch size) options = [12, 24, 36, ..., 60]. +GBS=12 + + + + + +# Set scatter-gather communication optimization options. +if [ ${SCATTER_GATHER} == "YES" ]; then + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 " +elif [ ${SCATTER_GATHER} == "NO" ]; then + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 --no-scatter-gather-tensors-in-pipeline " +else + echo "Invalid configuration" + exit 1 +fi + + +# Other params. +TP=8 +PP=12 +MBS=1 +NLS=96 +HS=12288 +NAH=96 +DDP=local +NNODES=12 + + +# Name of the job. +export JOB_NAME=results_figure_18_scatter_gather_${SCATTER_GATHER}_batch_size_${GBS} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. `pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/academic_paper_scripts/sc21/run_table_1.sh b/examples/academic_paper_scripts/sc21/run_table_1.sh old mode 100644 new mode 100755 index 1b15fb04582c90dc47fb1bbd3aca46feca2585ba..31884ccfd830afde8cf69c8de2d4c3c84aff1935 --- a/examples/academic_paper_scripts/sc21/run_table_1.sh +++ b/examples/academic_paper_scripts/sc21/run_table_1.sh @@ -1,145 +1,145 @@ -#!/bin/bash - -# ================================ -# Choose the case to run. -# ================================ -# model size options = [1.7B, 3.6B, 7.5B, 18B, 39B, 76B, 145B, 310B, 530B, 1T] -MODEL_SIZE=1.7B - - - - - - -if [ ${MODEL_SIZE} == "1.7B" ]; then - TP=1 - PP=1 - MBS=16 - GBS=512 - NLS=24 - HS=2304 - NAH=24 - DDP=torch - NNODES=4 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${MODEL_SIZE} == "3.6B" ]; then - TP=2 - PP=1 - MBS=16 - GBS=512 - NLS=30 - HS=3072 - NAH=32 - DDP=torch - NNODES=8 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${MODEL_SIZE} == "7.5B" ]; then - TP=4 - PP=1 - MBS=16 - GBS=512 - NLS=36 - HS=4096 - NAH=32 - DDP=torch - NNODES=16 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${MODEL_SIZE} == "18B" ]; then - TP=8 - PP=1 - MBS=8 - GBS=1024 - NLS=40 - HS=6144 - NAH=48 - DDP=torch - NNODES=32 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${MODEL_SIZE} == "39B" ]; then - TP=8 - PP=2 - MBS=4 - GBS=1536 - NLS=48 - HS=8192 - NAH=64 - DDP=local - NNODES=64 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -elif [ ${MODEL_SIZE} == "76B" ]; then - TP=8 - PP=4 - MBS=2 - GBS=1792 - NLS=60 - HS=10240 - NAH=80 - DDP=local - NNODES=128 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5" -elif [ ${MODEL_SIZE} == "145B" ]; then - TP=8 - PP=8 - MBS=2 - GBS=2304 - NLS=80 - HS=12288 - NAH=96 - DDP=local - NNODES=192 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5 " -elif [ ${MODEL_SIZE} == "310B" ]; then - TP=8 - PP=16 - MBS=1 - GBS=2160 - NLS=96 - HS=16384 - NAH=128 - DDP=local - NNODES=240 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 3 " -elif [ ${MODEL_SIZE} == "530B" ]; then - TP=8 - PP=35 - MBS=1 - GBS=2520 - NLS=105 - HS=20480 - NAH=128 - DDP=local - NNODES=315 - 
MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 1 " -elif [ ${MODEL_SIZE} == "1T" ]; then - TP=8 - PP=64 - MBS=1 - GBS=3072 - NLS=128 - HS=25600 - NAH=160 - DDP=local - NNODES=384 - MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " -else - echo "Invalid configuration" - exit 1 -fi - - -# Name of the job -export JOB_NAME=results_table_1_model_size_${MODEL_SIZE} - - -# Import the configs. -. `pwd`/CONFIG.sh - - -# Submit the job. -. `pwd`/SBATCH.sh - - -exit 0 - - - +#!/bin/bash + +# ================================ +# Choose the case to run. +# ================================ +# model size options = [1.7B, 3.6B, 7.5B, 18B, 39B, 76B, 145B, 310B, 530B, 1T] +MODEL_SIZE=1.7B + + + + + + +if [ ${MODEL_SIZE} == "1.7B" ]; then + TP=1 + PP=1 + MBS=16 + GBS=512 + NLS=24 + HS=2304 + NAH=24 + DDP=torch + NNODES=4 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${MODEL_SIZE} == "3.6B" ]; then + TP=2 + PP=1 + MBS=16 + GBS=512 + NLS=30 + HS=3072 + NAH=32 + DDP=torch + NNODES=8 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${MODEL_SIZE} == "7.5B" ]; then + TP=4 + PP=1 + MBS=16 + GBS=512 + NLS=36 + HS=4096 + NAH=32 + DDP=torch + NNODES=16 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${MODEL_SIZE} == "18B" ]; then + TP=8 + PP=1 + MBS=8 + GBS=1024 + NLS=40 + HS=6144 + NAH=48 + DDP=torch + NNODES=32 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${MODEL_SIZE} == "39B" ]; then + TP=8 + PP=2 + MBS=4 + GBS=1536 + NLS=48 + HS=8192 + NAH=64 + DDP=local + NNODES=64 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +elif [ ${MODEL_SIZE} == "76B" ]; then + TP=8 + PP=4 + MBS=2 + GBS=1792 + NLS=60 + HS=10240 + NAH=80 + DDP=local + NNODES=128 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5" +elif [ ${MODEL_SIZE} == "145B" ]; then + TP=8 + PP=8 + MBS=2 + GBS=2304 + NLS=80 + HS=12288 + NAH=96 + DDP=local + NNODES=192 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5 " +elif [ ${MODEL_SIZE} == "310B" ]; then + TP=8 + PP=16 + MBS=1 + GBS=2160 + NLS=96 + HS=16384 + NAH=128 + DDP=local + NNODES=240 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 3 " +elif [ ${MODEL_SIZE} == "530B" ]; then + TP=8 + PP=35 + MBS=1 + GBS=2520 + NLS=105 + HS=20480 + NAH=128 + DDP=local + NNODES=315 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 1 " +elif [ ${MODEL_SIZE} == "1T" ]; then + TP=8 + PP=64 + MBS=1 + GBS=3072 + NLS=128 + HS=25600 + NAH=160 + DDP=local + NNODES=384 + MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform " +else + echo "Invalid configuration" + exit 1 +fi + + +# Name of the job +export JOB_NAME=results_table_1_model_size_${MODEL_SIZE} + + +# Import the configs. +. `pwd`/CONFIG.sh + + +# Submit the job. +. 
`pwd`/SBATCH.sh + + +exit 0 + + + diff --git a/examples/bert/train_bert_340m_distributed.sh b/examples/bert/train_bert_340m_distributed.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/README.md b/examples/export/ptq_and_trtllm_export/README.md index abaa0d7645fcad39bc3c8f00f68df1453e1a66fb..2605910869852136e33357de115fb256d0e9fcb1 100644 --- a/examples/export/ptq_and_trtllm_export/README.md +++ b/examples/export/ptq_and_trtllm_export/README.md @@ -18,7 +18,7 @@ make -C docker release_build Once the container is built, install `nvidia-modelopt` and additional dependencies for sharded checkpoint support: ```sh pip install "nvidia-modelopt[all]~=0.13.0" --extra-index-url https://pypi.nvidia.com -pip install zarr tensorstore==0.1.45 +pip install zarr tensorstore!=0.1.46 ``` TensorRT-LLM quantization functionalities are currently packaged in `nvidia-modelopt`. You can find more documentation about `nvidia-modelopt` [here](https://nvidia.github.io/TensorRT-Model-Optimizer/). @@ -292,4 +292,4 @@ export trtllm_options=" \ trtllm-build ${trtllm_options} python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer mistralai/Mixtral-8x7B-v0.1 -``` \ No newline at end of file +``` diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh old mode 100644 new mode 100755 diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh old mode 100644 new mode 100755 diff --git a/examples/gpt3/gpt_config.yaml b/examples/gpt3/gpt_config.yaml index 06257827fdfbd32d262d0da032930ebbaaf578aa..4f87c0af7dbc389e015bc6bd1cb74b945c18dd85 100644 --- a/examples/gpt3/gpt_config.yaml +++ b/examples/gpt3/gpt_config.yaml @@ -1,300 +1,301 @@ -# WARNING: Yaml configs is currently an experimental feature -language_model: - # model architecture - num_layers: 24 - hidden_size: 1024 - num_attention_heads: 16 - num_query_groups: null - - ffn_hidden_size: null - kv_channels: null - hidden_dropout: 0.0 - attention_dropout: 0.0 - fp32_residual_connection: False - - apply_residual_connection_post_layernorm: False - layernorm_epsilon: 1.e-5 - layernorm_zero_centered_gamma: True - add_bias_linear: False - bias_activation_fusion: False - add_qkv_bias: False - gated_linear_unit: False - activation_func: swiglu - num_moe_experts: null - rotary_interleaved: False - window_size: null - - # initialization - init_method: null - init_method_std: 0.02 - output_layer_init_method: null - - # mixed-precision - apply_query_key_layer_scaling: False - attention_softmax_in_fp32: False - - # fusion - bias_swiglu_fusion: True - masked_softmax_fusion: True - persist_layer_norm: False - memory_efficient_layer_norm: False - bias_dropout_fusion: True - 
apply_rope_fusion: True - - # activation recomputation - recompute_granularity: null - recompute_method: null - recompute_num_layers: null - distribute_saved_activations: null - - # fp8 related - fp8: null - fp8_margin: 0 - fp8_interval: 1 - fp8_amax_history_len: 1 - fp8_amax_compute_algo: "most_recent" - fp8_wgrad: True - - # miscellaneous - clone_scatter_output_in_embedding: True - - normalization: "LayerNorm" # alt value supported by TE: "RMSNorm" - - # MoE related - moe_router_load_balancing_type: "aux_loss" - moe_router_topk: 2 - moe_router_topk_limited_devices: null - moe_grouped_gemm: False - moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss. - moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss - moe_input_jitter_eps: null - moe_token_dropping: False - -model_parallel: - # Model parallelism - tensor_model_parallel_size: 1 - context_parallel_size: 1 - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null - sequence_parallel: True - expert_model_parallel_size: 1 - - # Initialization - perform_initialization: True - use_cpu_initialization: null - - # Training - fp16: False - bf16: True - params_dtype: null # Set from above arguments for core - timers: null - - # Optimizations - gradient_accumulation_fusion: True - async_tensor_model_parallel_allreduce: True - tp_comm_overlap: False - - # Debug Options - tp_comm_split_ag: True - tp_comm_atomic_ag: True - tp_comm_split_rs: True - tp_comm_atomic_rs: True - tp_comm_bulk_wgrad: True - tp_comm_bulk_dgrad: True - - # Parallelism - finalize_model_grads_func: null - - # Pipeline Parallel - pipeline_dtype: null - grad_scale_func: null - enable_autocast: False - autocast_dtype: null - variable_seq_lengths: False - num_microbatches_with_partial_activation_checkpoints: null - overlap_p2p_comm: False - batch_p2p_comm: True - batch_p2p_sync: True - use_ring_exchange_p2p: False - deallocate_pipeline_outputs: False - no_sync_func: null - grad_sync_func: null - param_sync_func: null - pipeline_model_parallel_split_rank: null - - # CPU Offloading - cpu_offloading: False - cpu_offloading_num_layers: 0 - _cpu_offloading_context: null - cpu_offloading_weights: False - cpu_offloading_activations: True - - # Timing - barrier_with_L1_time: True - -# training: -use_legacy_models: False -spec: null -micro_batch_size: 2 -global_batch_size: 128 -rampup_batch_size: [32, 32, 65324160] -check_for_nan_in_loss_and_grad: True -num_layers_per_virtual_pipeline_stage: null - -encoder_num_layers: null -decoder_num_layers: null -rotary_seq_len_interpolation_factor: null -add_position_embedding: False -make_vocab_size_divisible_by: 128 -group_query_attention: False - - -exit_signal_handler: False -exit_duration_in_mins: null -exit_interval: null - -untie_embeddings_and_output_weights: True -position_embedding_type: rope -rotary_percent: 0.5 -openai_gelu: False -squared_relu: False -swiglu: True -onnx_safe: null -bert_binary_head: True -max_position_embeddings: 4096 - -transformer_impl: local -use_flash_attn: False -seed: 1234 -data_parallel_random_init: False - -# Optimizer -optimizer: adam -lr: 2.5e-4 -lr_decay_style: cosine -lr_decay_iters: null -lr_decay_samples: 255126953 -lr_warmup_fraction: null -lr_warmup_iters: 0 -lr_warmup_samples: 81381 -lr_warmup_init: 0.0 -min_lr: 2.5e-5 -weight_decay: 0.1 -start_weight_decay: null -end_weight_decay: null -weight_decay_incr_style: constant -clip_grad: 1.0 -adam_beta1: 0.9 -adam_beta2: 0.95 -adam_eps: 1.e-08 -sgd_momentum: 0.9 -override_opt_param_scheduler: False 
-use_checkpoint_opt_param_scheduler: False - -# checkpointing arguments -save: null -save_interval: 20000 -no_save_optim: null -no_save_rng: null -load: null -no_load_optim: null -no_load_rng: null -finetune: False -use_checkpoint_args: False -exit_on_missing_checkpoint: False - -# loss arguments -loss_scale: null -initial_loss_scale: 4294967296 -min_loss_scale: 1.0 -loss_scale_window: 1000 -hysteresis: 2 -accumulate_allreduce_grads_in_fp32: False -fp16_lm_cross_entropy: False - -# distributed arguments -distributed_backend: nccl -distributed_timeout_minutes: 10 -overlap_grad_reduce: False -align_grad_reduce: True -overlap_param_gather: False -align_param_gather: False -scatter_gather_tensors_in_pipeline: True -local_rank: null -lazy_mpu_init: null -empty_unused_memory_level: 0 -standalone_embedding_stage: False -use_distributed_optimizer: False -nccl_communicator_config_path: null - -train_iters: null -eval_iters: 32 -eval_interval: 2000 -skip_train: False - -adlr_autoresume: False -adlr_autoresume_interval: 1000 - -# garbage collection -manual_gc: False -manual_gc_interval: 0 -manual_gc_eval: True - -tp_comm_overlap_cfg: null - -#data -data_path: null -split: '99,1,0' -train_data_path: null -valid_data_path: null -test_data_path: null -data_cache_path: null -mock_data: False -vocab_size: null -vocab_file: null -merge_file: null -vocab_extra_ids: 0 -seq_length: 4096 -encoder_seq_length: null -decoder_seq_length: null -retriever_seq_length: 256 -sample_rate: 1.0 -mask_prob: 0.15 -short_seq_prob: 0.1 -num_workers: 2 -tokenizer_type: GPTSentencePieceTokenizer -tokenizer_model: null -reset_position_ids: False -reset_attention_mask: False -eod_mask_loss: False -train_samples: 268554688 -dataloader_type: null - -#profile: -profile: False -profile_ranks: [0] -profile_step_end: 12 -profile_step_start: 10 - -#logging: -log_params_norm: True -log_num_zeros_in_grad: True -log_throughput: False -log_progress: False -timing_log_level: 0 -timing_log_option: minmax -tensorboard_log_interval: 1 -tensorboard_queue_size: 1000 -log_timers_to_tensorboard: False -log_validation_ppl_to_tensorboard: False -log_memory_to_tensorboard: False -log_world_size_to_tensorboard: False -log_loss_scale_to_tensorboard: True -wandb_project: '' -wandb_exp_name: '' -wandb_save_dir: '' -enable_one_logger: True -one_logger_project: megatron-lm -one_logger_run_name: null -log_interval: 100 -tensorboard_dir: null +# WARNING: Yaml configs is currently an experimental feature +language_model: + # model architecture + num_layers: 24 + hidden_size: 1024 + num_attention_heads: 16 + num_query_groups: null + + ffn_hidden_size: null + kv_channels: null + hidden_dropout: 0.0 + attention_dropout: 0.0 + fp32_residual_connection: False + + apply_residual_connection_post_layernorm: False + layernorm_epsilon: 1.e-5 + layernorm_zero_centered_gamma: True + add_bias_linear: False + bias_activation_fusion: False + add_qkv_bias: False + gated_linear_unit: False + activation_func: swiglu + num_moe_experts: null + rotary_interleaved: False + window_size: null + + # initialization + init_method: null + init_method_std: 0.02 + output_layer_init_method: null + + # mixed-precision + apply_query_key_layer_scaling: False + attention_softmax_in_fp32: False + + # fusion + bias_swiglu_fusion: True + masked_softmax_fusion: True + persist_layer_norm: False + memory_efficient_layer_norm: False + bias_dropout_fusion: True + apply_rope_fusion: True + + # activation recomputation + recompute_granularity: null + recompute_method: null + recompute_num_layers: null + 
distribute_saved_activations: null + + # fp8 related + fp8: null + fp8_margin: 0 + fp8_interval: 1 + fp8_amax_history_len: 1 + fp8_amax_compute_algo: "most_recent" + fp8_wgrad: True + + # miscellaneous + clone_scatter_output_in_embedding: True + + normalization: "LayerNorm" # alt value supported by TE: "RMSNorm" + + # MoE related + moe_router_load_balancing_type: "aux_loss" + moe_router_topk: 2 + moe_router_group_topk: null + moe_router_num_groups: null + moe_grouped_gemm: False + moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss. + moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss + moe_input_jitter_eps: null + moe_token_dropping: False + +model_parallel: + # Model parallelism + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + sequence_parallel: True + expert_model_parallel_size: 1 + + # Initialization + perform_initialization: True + use_cpu_initialization: null + + # Training + fp16: False + bf16: True + params_dtype: null # Set from above arguments for core + timers: null + + # Optimizations + gradient_accumulation_fusion: True + async_tensor_model_parallel_allreduce: True + tp_comm_overlap: False + + # Debug Options + tp_comm_split_ag: True + tp_comm_atomic_ag: True + tp_comm_split_rs: True + tp_comm_atomic_rs: True + tp_comm_bulk_wgrad: True + tp_comm_bulk_dgrad: True + + # Parallelism + finalize_model_grads_func: null + + # Pipeline Parallel + pipeline_dtype: null + grad_scale_func: null + enable_autocast: False + autocast_dtype: null + variable_seq_lengths: False + num_microbatches_with_partial_activation_checkpoints: null + overlap_p2p_comm: False + batch_p2p_comm: True + batch_p2p_sync: True + use_ring_exchange_p2p: False + deallocate_pipeline_outputs: False + no_sync_func: null + grad_sync_func: null + param_sync_func: null + pipeline_model_parallel_split_rank: null + + # CPU Offloading + cpu_offloading: False + cpu_offloading_num_layers: 0 + _cpu_offloading_context: null + cpu_offloading_weights: False + cpu_offloading_activations: True + + # Timing + barrier_with_L1_time: True + +# training: +use_legacy_models: False +spec: null +micro_batch_size: 2 +global_batch_size: 128 +rampup_batch_size: [32, 32, 65324160] +check_for_nan_in_loss_and_grad: True +num_layers_per_virtual_pipeline_stage: null + +encoder_num_layers: null +decoder_num_layers: null +rotary_seq_len_interpolation_factor: null +add_position_embedding: False +make_vocab_size_divisible_by: 128 +group_query_attention: False + + +exit_signal_handler: False +exit_duration_in_mins: null +exit_interval: null + +untie_embeddings_and_output_weights: True +position_embedding_type: rope +rotary_percent: 0.5 +openai_gelu: False +squared_relu: False +swiglu: True +onnx_safe: null +bert_binary_head: True +max_position_embeddings: 4096 + +transformer_impl: local +use_flash_attn: False +seed: 1234 +data_parallel_random_init: False + +# Optimizer +optimizer: adam +lr: 2.5e-4 +lr_decay_style: cosine +lr_decay_iters: null +lr_decay_samples: 255126953 +lr_warmup_fraction: null +lr_warmup_iters: 0 +lr_warmup_samples: 81381 +lr_warmup_init: 0.0 +min_lr: 2.5e-5 +weight_decay: 0.1 +start_weight_decay: null +end_weight_decay: null +weight_decay_incr_style: constant +clip_grad: 1.0 +adam_beta1: 0.9 +adam_beta2: 0.95 +adam_eps: 1.e-08 +sgd_momentum: 0.9 +override_opt_param_scheduler: False +use_checkpoint_opt_param_scheduler: False + +# checkpointing arguments +save: null +save_interval: 20000 
+no_save_optim: null +no_save_rng: null +load: null +no_load_optim: null +no_load_rng: null +finetune: False +use_checkpoint_args: False +exit_on_missing_checkpoint: False + +# loss arguments +loss_scale: null +initial_loss_scale: 4294967296 +min_loss_scale: 1.0 +loss_scale_window: 1000 +hysteresis: 2 +accumulate_allreduce_grads_in_fp32: False +fp16_lm_cross_entropy: False + +# distributed arguments +distributed_backend: nccl +distributed_timeout_minutes: 10 +overlap_grad_reduce: False +align_grad_reduce: True +overlap_param_gather: False +align_param_gather: False +scatter_gather_tensors_in_pipeline: True +local_rank: null +lazy_mpu_init: null +empty_unused_memory_level: 0 +standalone_embedding_stage: False +use_distributed_optimizer: False +nccl_communicator_config_path: null + +train_iters: null +eval_iters: 32 +eval_interval: 2000 +skip_train: False + +adlr_autoresume: False +adlr_autoresume_interval: 1000 + +# garbage collection +manual_gc: False +manual_gc_interval: 0 +manual_gc_eval: True + +tp_comm_overlap_cfg: null + +#data +data_path: null +split: '99,1,0' +train_data_path: null +valid_data_path: null +test_data_path: null +data_cache_path: null +mock_data: False +vocab_size: null +vocab_file: null +merge_file: null +vocab_extra_ids: 0 +seq_length: 4096 +encoder_seq_length: null +decoder_seq_length: null +retriever_seq_length: 256 +sample_rate: 1.0 +mask_prob: 0.15 +short_seq_prob: 0.1 +num_workers: 2 +tokenizer_type: GPTSentencePieceTokenizer +tokenizer_model: null +reset_position_ids: False +reset_attention_mask: False +eod_mask_loss: False +train_samples: 268554688 +dataloader_type: null + +#profile: +profile: False +profile_ranks: [0] +profile_step_end: 12 +profile_step_start: 10 + +#logging: +log_params_norm: True +log_num_zeros_in_grad: True +log_throughput: False +log_progress: False +timing_log_level: 0 +timing_log_option: minmax +tensorboard_log_interval: 1 +tensorboard_queue_size: 1000 +log_timers_to_tensorboard: False +log_validation_ppl_to_tensorboard: False +log_memory_to_tensorboard: False +log_world_size_to_tensorboard: False +log_loss_scale_to_tensorboard: True +wandb_project: '' +wandb_exp_name: '' +wandb_save_dir: '' +enable_one_logger: True +one_logger_project: megatron-lm +one_logger_run_name: null +log_interval: 100 +tensorboard_dir: null diff --git a/examples/gpt3/hostfile_gpt_567B b/examples/gpt3/hostfile_gpt_567B new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/run_GPT-MOE_1nodes.sh b/examples/gpt3/run_gpt_567B_1nodes.sh old mode 100644 new mode 100755 similarity index 52% rename from run_GPT-MOE_1nodes.sh rename to examples/gpt3/run_gpt_567B_1nodes.sh index d38585e8623f014f084492c266566ea07c424f67..32c5ac5e07239108483f27c2c2657ca5cf1fadbb --- a/run_GPT-MOE_1nodes.sh +++ b/examples/gpt3/run_gpt_567B_1nodes.sh @@ -7,10 +7,10 @@ do fi done -mpirun -np 8 --allow-run-as-root \ - train_GPT-MOE_567B_1nodes.sh localhost --profiling=$profiling > output.log 2>&1 +mpirun -np 8 --allow-run-as-root \ + train_gpt_567B_1nodes.sh localhost --profiling=$profiling > output.log 2>&1 wait rm -rf CKPT -rm -rf mixtral_dataset/my-mixtral_text_document \ No newline at end of file +rm -rf mixtral_dataset/my-mixtral_text_document diff --git a/run_mixtral8x7B_2nodes.sh b/examples/gpt3/run_gpt_567B_multinodes.sh old mode 100644 new mode 100755 similarity index 67% rename from run_mixtral8x7B_2nodes.sh rename to examples/gpt3/run_gpt_567B_multinodes.sh index 
bd92ab0c26f82ad2a043a52aadb672d352151160..10821c467991dd816487a016437104e98ee2f788 --- a/run_mixtral8x7B_2nodes.sh +++ b/examples/gpt3/run_gpt_567B_multinodes.sh @@ -7,13 +7,13 @@ do fi done -mpirun -np 16 --hostfile mixtralnodes \ +mpirun -np 512 --hostfile hostfile_gpt_567B \ --allow-run-as-root \ --bind-to none \ --mca plm_rsh_no_tree_spawn 1 \ - train_mixtral_8x7B_2nodes.sh node021 --profiling=$profiling > output.log 2>&1 + train_gpt_567B_multinodes.sh node002 --profiling=$profiling > output.log 2>&1 wait rm -rf CKPT -#rm -rf mixtral_dataset/my-mixtral_text_document \ No newline at end of file +#rm -rf mixtral_dataset/my-mixtral_text_document diff --git a/examples/gpt3/train_gpt3_175b_distributed.sh b/examples/gpt3/train_gpt3_175b_distributed.sh old mode 100644 new mode 100755 index 7d2c01b315799ba70bdf7a29506d6e0f8d630afc..fbc7e384ed746acd71cdb362bd337d51a6ced33b --- a/examples/gpt3/train_gpt3_175b_distributed.sh +++ b/examples/gpt3/train_gpt3_175b_distributed.sh @@ -1,82 +1,82 @@ -#!/bin/bash - -# Runs the "175B" parameter model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NUM_NODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) - -CHECKPOINT_PATH=$1 # -TENSORBOARD_LOGS_PATH=$2 # -VOCAB_FILE=$3 #/gpt2-vocab.json -MERGE_FILE=$4 #/gpt2-merges.txt -DATA_PATH=$5 #_text_document - -DISTRIBUTED_ARGS=( - --nproc_per_node $GPUS_PER_NODE - --nnodes $NUM_NODES - --master_addr $MASTER_ADDR - --master_port $MASTER_PORT -) - -GPT_MODEL_ARGS=( - --num-layers 96 - --hidden-size 12288 - --num-attention-heads 96 - --seq-length 2048 - --max-position-embeddings 2048 - --attention-backend auto # Can use (flash/fused/unfused/local) -) - -TRAINING_ARGS=( - --micro-batch-size 1 - --global-batch-size 1536 - --rampup-batch-size 16 16 5859375 - --train-iters 500000 - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.95 - --init-method-std 0.006 - --clip-grad 1.0 - --fp16 - --lr 6.0e-5 - --lr-decay-style cosine - --min-lr 6.0e-6 - --lr-warmup-fraction .001 - --lr-decay-iters 430000 -) - -MODEL_PARALLEL_ARGS=( - --tensor-model-parallel-size 8 - --pipeline-model-parallel-size 16 -) - -DATA_ARGS=( - --data-path $DATA_PATH - --vocab-file $VOCAB_FILE - --merge-file $MERGE_FILE - --split 949,50,1 -) - -EVAL_AND_LOGGING_ARGS=( - --log-interval 100 - --save-interval 10000 - --eval-interval 1000 - --save $CHECKPOINT_PATH - --load $CHECKPOINT_PATH - --eval-iters 10 - --tensorboard-dir $TENSORBOARD_LOGS_PATH -) - -torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \ - ${GPT_MODEL_ARGS[@]} \ - ${TRAINING_ARGS[@]} \ - ${MODEL_PARALLEL_ARGS[@]} \ - ${DATA_ARGS[@]} \ - ${EVAL_AND_LOGGING_ARGS[@]} +#!/bin/bash + +# Runs the "175B" parameter model + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NUM_NODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +CHECKPOINT_PATH=$1 # +TENSORBOARD_LOGS_PATH=$2 # +VOCAB_FILE=$3 #/gpt2-vocab.json +MERGE_FILE=$4 #/gpt2-merges.txt +DATA_PATH=$5 #_text_document + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +GPT_MODEL_ARGS=( + --num-layers 96 + --hidden-size 12288 + --num-attention-heads 96 + --seq-length 2048 + --max-position-embeddings 2048 + --attention-backend auto # Can use (flash/fused/unfused/local) +) + +TRAINING_ARGS=( + --micro-batch-size 1 + --global-batch-size 1536 + --rampup-batch-size 16 16 
5859375 + --train-iters 500000 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 + --init-method-std 0.006 + --clip-grad 1.0 + --fp16 + --lr 6.0e-5 + --lr-decay-style cosine + --min-lr 6.0e-6 + --lr-warmup-fraction .001 + --lr-decay-iters 430000 +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size 8 + --pipeline-model-parallel-size 16 +) + +DATA_ARGS=( + --data-path $DATA_PATH + --vocab-file $VOCAB_FILE + --merge-file $MERGE_FILE + --split 949,50,1 +) + +EVAL_AND_LOGGING_ARGS=( + --log-interval 100 + --save-interval 10000 + --eval-interval 1000 + --save $CHECKPOINT_PATH + --load $CHECKPOINT_PATH + --eval-iters 10 + --tensorboard-dir $TENSORBOARD_LOGS_PATH +) + +torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \ + ${GPT_MODEL_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} diff --git a/train_GPT-MOE_567B_1nodes.sh b/examples/gpt3/train_gpt_567B_1nodes.sh old mode 100644 new mode 100755 similarity index 88% rename from train_GPT-MOE_567B_1nodes.sh rename to examples/gpt3/train_gpt_567B_1nodes.sh index 1eac615fd6a3db7949b1b72ef824b9db7ead3e46..40bd6b7fdb48d95b2de28c48c8627deeb0601bfc --- a/train_GPT-MOE_567B_1nodes.sh +++ b/examples/gpt3/train_gpt_567B_1nodes.sh @@ -4,18 +4,23 @@ for para in $* do if [[ $para == --profiling* ]];then profiling=${para#*=} - export GPU_FLUSH_ON_EXECUTION=1 - export HIP_DIRECT_DISPATCH=0 fi done +# Runs GPT 567B model source /opt/dtk/env.sh -# Runs Mixtral 8x7B model + + +# default env +CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )" +MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR})) +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH +export GLOG_minloglevel=3 export CUDA_DEVICE_MAX_CONNECTIONS=1 export HSA_FORCE_FINE_GRAIN_PCIE=1 export OMP_NUM_THREADS=1 export GPU_MAX_HW_QUEUES=10 +# nccl env export NCCL_ALGO=Ring export NCCL_MIN_NCHANNELS=32 export NCCL_MAX_NCHANNELS=32 @@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7 export NCCL_NET_GDR_READ=1 export RCCL_SDMA_COPY_ENABLE=0 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 -#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" +export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" + +# enable BatchLinear export GROUPED_GEMM_BatchLinear=1 -export GLOG_minloglevel=3 RANK=$OMPI_COMM_WORLD_RANK LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK @@ -96,7 +102,6 @@ TRAINING_ARGS=( --bf16 --overlap-param-gather --overlap-grad-reduce - #--tp-comm-overlap ) TORCH_PROFIE_ARGS=( @@ -104,18 +109,10 @@ TORCH_PROFIE_ARGS=( --profile-ranks 0 1 2 3 4 5 6 7 --profile-step-start 3 --profile-step-end 4 - --profile-dir torch_prof_gpt_1nodes + --profile-dir torch_prof_gpt_1nodes_tp2-pp1-ep8-ep_tp1 --use-pytorch-profiler ) -HIP_PROFIE_ARGS=( - --profile - --profile-ranks 0 1 2 3 4 5 6 7 - --profile-step-start 4 - --profile-step-end 5 - --use-hip-profiler -) - MODEL_PARALLEL_ARGS=( --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 @@ -157,10 +154,6 @@ APP="python3 -u pretrain_gpt.py \ if [[ $profiling == "torch" ]]; then APP+=" ${TORCH_PROFIE_ARGS[@]}" -elif [[ $profiling == "hip" ]]; then - mkdir -p hip_prof_data - APP+=" ${HIP_PROFIE_ARGS[@]}" - APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}" fi #for hygon cpu @@ -205,4 +198,4 @@ case ${LOCAL_RANK} in ${APP} #numactl --cpunodebind=7 --membind=7 ${APP} ;; -esac \ No newline at end of file +esac diff --git a/train_GPT-MOE_567B.sh b/examples/gpt3/train_gpt_567B_multinodes.sh old mode
100644 new mode 100755 similarity index 87% rename from train_GPT-MOE_567B.sh rename to examples/gpt3/train_gpt_567B_multinodes.sh index f298fa8e97ea91eab678391d50d8f964029eb38f..9751e0888db79b3179ea076510967811bd51904c --- a/train_GPT-MOE_567B.sh +++ b/examples/gpt3/train_gpt_567B_multinodes.sh @@ -4,18 +4,23 @@ for para in $* do if [[ $para == --profiling* ]];then profiling=${para#*=} - export GPU_FLUSH_ON_EXECUTION=1 - export HIP_DIRECT_DISPATCH=0 fi done +# Runs GPT 567B model source /opt/dtk/env.sh -# Runs Mixtral 8x7B model + + +# default env +CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )" +MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR})) +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH +export GLOG_minloglevel=3 export CUDA_DEVICE_MAX_CONNECTIONS=1 export HSA_FORCE_FINE_GRAIN_PCIE=1 export OMP_NUM_THREADS=1 export GPU_MAX_HW_QUEUES=10 +# nccl env export NCCL_ALGO=Ring export NCCL_MIN_NCHANNELS=32 export NCCL_MAX_NCHANNELS=32 @@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7 export NCCL_NET_GDR_READ=1 export RCCL_SDMA_COPY_ENABLE=0 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 -#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" +export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" + +# enable BatchLinear export GROUPED_GEMM_BatchLinear=1 -export GLOG_minloglevel=3 RANK=$OMPI_COMM_WORLD_RANK LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK @@ -49,7 +55,7 @@ MODEL_ARGS=( --disable-bias-linear --seq-length 8192 --max-position-embeddings 32768 - --num-layers 64 + --num-layers 32 #64 --hidden-size 8192 --ffn-hidden-size 32768 --num-attention-heads 64 @@ -72,7 +78,7 @@ MOE_ARGS=( --moe-token-dispatcher-type alltoall --moe-expert-capacity-factor 0.5 --moe-pad-expert-input-to-capacity - --moe-grouped-gemm + #--moe-grouped-gemm ) DATA_ARGS=( @@ -84,7 +90,7 @@ DATA_ARGS=( TRAINING_ARGS=( --micro-batch-size 1 - --global-batch-size 4096 + --global-batch-size 1024 --lr 1e-4 --train-iters 10 --lr-decay-iters 320000 @@ -96,7 +102,6 @@ TRAINING_ARGS=( --bf16 --overlap-param-gather --overlap-grad-reduce - #--tp-comm-overlap ) TORCH_PROFIE_ARGS=( @@ -104,23 +109,16 @@ TORCH_PROFIE_ARGS=( --profile-ranks 0 1 2 3 4 5 6 7 --profile-step-start 3 --profile-step-end 4 - --profile-dir torch_prof_gpt + --profile-dir torch_prof_gpt_64nodes_tp2-pp16-ep16-ep_tp1-cp2 --use-pytorch-profiler ) -HIP_PROFIE_ARGS=( - --profile - --profile-ranks 0 1 2 3 4 5 6 7 - --profile-step-start 4 - --profile-step-end 5 - --use-hip-profiler -) - MODEL_PARALLEL_ARGS=( --tensor-model-parallel-size 2 --pipeline-model-parallel-size 16 --expert-model-parallel-size 16 --expert-tensor-parallel-size 1 + --context-parallel-size 2 --use-distributed-optimizer --sequence-parallel ) @@ -157,10 +155,6 @@ APP="python3 -u pretrain_gpt.py \ if [[ $profiling == "torch" ]]; then APP+=" ${TORCH_PROFIE_ARGS[@]}" -elif [[ $profiling == "hip" ]]; then - mkdir -p hip_prof_data - APP+=" ${HIP_PROFIE_ARGS[@]}" - APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}" fi #for hygon cpu diff --git a/examples/inference/gpt/gpt_batch_inference.py b/examples/inference/gpt/gpt_batch_inference.py index 050b230cef70d56203b7f9270a6166d7251f0769..604408bcbcf4a91921bca0b1ec5f9fbcb42c6a48 100644 --- a/examples/inference/gpt/gpt_batch_inference.py +++ b/examples/inference/gpt/gpt_batch_inference.py @@ -1,115 +1,200 @@ -import os -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig -from pretrain_gpt
import model_provider -import torch -import sys -from argparse import Namespace -from megatron.core.inference.engines.abstract_engine import AbstractEngine -from megatron.core.inference.engines.mcore_engine import MCoreEngine -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper -from megatron.core.inference.inference_request import InferenceRequest -from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController -from megatron.core.transformer.module import MegatronModule -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir, os.path.pardir))) - -from megatron.training import get_args -from megatron.training import get_tokenizer -from megatron.training.checkpointing import load_checkpoint -from megatron.core import mpu -from megatron.training.initialize import initialize_megatron -from megatron.training import get_model -from typing import List - -def add_text_generate_args(parser): - """Text generation arguments.""" - group = parser.add_argument_group(title='text generation') - - group.add_argument("--temperature", type=float, default=1.0, - help='Sampling temperature.') - group.add_argument("--top_k", type=int, default=1, - help='Top k sampling.') - group.add_argument("--top_p", type=float, default=0.0, - help='Top p sampling.') - group.add_argument("--return-log-probs", action='store_true', default=False, - help='Return the log probabilities of the final output tokens') - group.add_argument("--num-tokens-to-generate", type=int, default=30, - help='Number of tokens to generate for each prompt') - group.add_argument("--prompts", metavar='N', type=str, nargs='+', - help='Input prompts with each prompt within quotes and seperated by space') - group.add_argument("--max-batch-size", type=int, default=1, - help='Max number of prompts to process at once') - return parser - - -def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine: - """Utility to get the relevant backend for running inference - - This function will automatically chose the TRTLLMBackend when possible, and if not revert to Mcore backend if the user does not specify any backends. TRT LLM Backend is not implmented yet. - - Args: - args (Namespace): The user arguments parsed from command line - model (MegatronModule): The megatron model . - - Returns: - AbstractBackend: The chosen backend - """ - tokenizer = get_tokenizer() - - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=args.hidden_size, - inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, - fp32_residual_connection=args.fp32_residual_connection, - params_dtype=args.params_dtype, - padded_vocab_size=args.padded_vocab_size - ) - - inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config) - text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer) - return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size) - -def main(): - """Main program.""" - - # Note: The default args passed here can be overwritten by using appropriate params (check arguments.py file) - # Micro batch size is not needed to be set by user. 
(It is calculated based on inference-batch-times-seqlen-threshold argument) - initialize_megatron(extra_args_provider=add_text_generate_args, - args_defaults={'no_load_rng': True, - 'no_load_optim': True, - 'micro_batch_size': 1, - 'exit_on_missing_checkpoint': True}) - - # Set up model and load checkpoint - model = get_model(model_provider, wrap_with_ddp=False) - load_checkpoint(model, None, None) - model = model[0] - - args = get_args() - - inference_engine = get_inference_engine(args, model) - - sampling_params = SamplingParams( - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - return_log_probs=args.return_log_probs, - num_tokens_to_generate=args.num_tokens_to_generate) - - results: List[InferenceRequest] = inference_engine.generate( - prompts=args.prompts, sampling_params=sampling_params - ) - - if torch.distributed.get_rank() == 0: - for idx, result in enumerate(results): - print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ') - result = { - 'id': result.request_id, - 'input_prompt': result.prompt, - 'generated_text': result.generated_text, - 'generated_tokens' : result.generated_tokens - } - print(result) - -if __name__ == "__main__": - main() +import os +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from pretrain_gpt import model_provider +import torch +import sys +import time +import tqdm +import warnings +from argparse import Namespace +from megatron.core.inference.engines.abstract_engine import AbstractEngine +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) +from megatron.core.transformer.module import MegatronModule + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +from megatron.training import get_args +from megatron.training import get_tokenizer +from megatron.training.checkpointing import load_checkpoint +from megatron.core import mpu +from megatron.training.initialize import initialize_megatron +from megatron.training import get_model +import asyncio +from typing import AsyncIterator, List + + + +def add_text_generate_args(parser): + """Text generation arguments.""" + group = parser.add_argument_group(title='text generation') + + group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') + group.add_argument("--top_k", type=int, default=1, help='Top k sampling.') + group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') + group.add_argument( + "--return-log-probs", + action='store_true', + default=False, + help='Return the log probabilities of the final output tokens', + ) + group.add_argument( + "--num-tokens-to-generate", + type=int, + default=30, + help='Number of tokens to generate for each prompt', + ) + group.add_argument( + "--prompts", + metavar='N', + type=str, + nargs='+', + help='Input prompts with each prompt within quotes and separated by space', + ) + group.add_argument( + "--max-batch-size", type=int, default=8, dest="inference_max_requests", + help='Max number of prompts to process at once' + ) + group.add_argument("--stream",
action="store_true", default=False, help="Stream output tokens") + return parser + + +def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine: + """Utility to get the relevant backend for running inference + + This function will automatically chose the TRTLLMBackend when possible, and if not revert to Mcore backend if the user does not specify any backends. TRT LLM Backend is not implmented yet. + + Args: + args (Namespace): The user arguments parsed from command line + model (MegatronModule): The megatron model . + + Returns: + AbstractBackend: The chosen backend + """ + tokenizer = get_tokenizer() + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=args.hidden_size, + inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, + fp32_residual_connection=args.fp32_residual_connection, + params_dtype=args.params_dtype, + padded_vocab_size=args.padded_vocab_size, + inference_max_requests=args.inference_max_requests, + inference_max_seq_length=args.inference_max_seq_length, + ) + + inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config) + text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer) + return MCoreEngine(text_generation_controller=text_generation_controller) + + +async def generate( + inference_engine: MCoreEngine, + sampling_params: SamplingParams, + prompts: List[str], +) -> List[InferenceRequest]: + async def collect_stream(prompt, request_id, stream_generator): + print(f"Request {request_id}: {prompt}", end="", flush=True) + prev_idx = 0 + async for output in stream_generator: + print(output.generated_text[prev_idx:], end="", flush=True) + prev_idx = len(output.generated_text) + print() + + request_ids: List[str] = [ + inference_engine.add_request( + prompt=prompt, inference_parameters=sampling_params, streaming=True + ) + for prompt in prompts + ] + stream_generators = [inference_engine.get_stream_generator(request_id) for request_id in request_ids] + + tasks = [ + asyncio.create_task(collect_stream(prompt, request_id, stream_generator)) + for (prompt, request_id, stream_generator) in zip(prompts, request_ids, stream_generators) + ] + + await inference_engine.run_engine_async() + await asyncio.gather(*tasks) + + results: List[InferenceRequest] = [ + inference_engine.scheduler.completed_request_pool[request_id] for request_id in request_ids + ] + + return results + +def main(): + """Main program.""" + + # Note: The default args passed here can be overwritten by using appropriate params (check arguments.py file) + # Micro batch size is not needed to be set by user. 
(It is calculated based on inference-batch-times-seqlen-threshold argument) + initialize_megatron( + extra_args_provider=add_text_generate_args, + args_defaults={ + 'no_load_rng': True, + 'no_load_optim': True, + 'micro_batch_size': 1, + 'exit_on_missing_checkpoint': True, + }, + ) + + # Set up model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + load_checkpoint(model, None, None) + model = model[0] + + args = get_args() + + inference_engine = get_inference_engine(args, model) + + sampling_params = SamplingParams( + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + return_log_probs=args.return_log_probs, + num_tokens_to_generate=args.num_tokens_to_generate, + ) + + if args.enable_cuda_graph: + print(f"Running warmup for CUDA graphs...") + inference_engine.generate( + prompts=args.prompts, sampling_params=sampling_params + ) + + start_time = time.perf_counter() + if args.stream: + results: List[InferenceRequest] = asyncio.run(generate(inference_engine, sampling_params, args.prompts)) + else: + results: List[InferenceRequest] = inference_engine.generate( + prompts=args.prompts, sampling_params=sampling_params, + ) + end_time = time.perf_counter() + latency = end_time - start_time + + if torch.distributed.get_rank() == 0: + for idx, result in enumerate(results): + print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ') + result = { + 'id': result.request_id, + 'input_prompt': result.prompt, + 'generated_text': result.generated_text, + 'generated_tokens': result.generated_tokens, + 'latency': latency, + } + print(result) + + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.1.sh b/examples/inference/llama_mistral/run_text_generation_llama3.1.sh old mode 100644 new mode 100755 index 06584f0917d157f4d8c91323d75c780bd058fc16..08db907c57d3656429716e689233f44634bb8c74 --- a/examples/inference/llama_mistral/run_text_generation_llama3.1.sh +++ b/examples/inference/llama_mistral/run_text_generation_llama3.1.sh @@ -1,56 +1,56 @@ -#!/bin/bash -# This example will start serving the Llama3.1-8B model -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NVTE_APPLY_QK_LAYER_SCALING=0 - -DISTRIBUTED_ARGS="--nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr 0.0.0.0 \ - --master_port 6000" - -# Ensure CHECKPOINT and TOKENIZER_MODEL are provided -if [ -z "$1" ] || [ -z "$2" ]; then - echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." 
- echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" - exit 1 -fi - -# Assign command-line arguments to variables -CHECKPOINT=$1 -TOKENIZER_MODEL=$2 - -pip install flask-restful - -torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ - --use-checkpoint-args \ - --disable-bias-linear \ - --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model ${TOKENIZER_MODEL} \ - --transformer-impl transformer_engine \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --attention-softmax-in-fp32 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --untie-embeddings-and-output-weights \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 500000 \ - --use-rope-scaling \ - --use-rotary-position-embeddings \ - --swiglu \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --ffn-hidden-size 14336 \ - --load ${CHECKPOINT} \ - --num-attention-heads 32 \ - --max-position-embeddings 131072 \ - --bf16 \ - --micro-batch-size 1 \ - --seq-length 8192 +#!/bin/bash +# This example will start serving the Llama3.1-8B model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." + echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --use-checkpoint-args \ + --disable-bias-linear \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 500000 \ + --use-rope-scaling \ + --use-rotary-position-embeddings \ + --swiglu \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 131072 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 8192 diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.sh b/examples/inference/llama_mistral/run_text_generation_llama3.sh old mode 100644 new mode 100755 index c5fc4103ab54dd34cb79fb65e4eb535328bd2e0a..fb233772ce9aa5e2eadf403e0d193b76f06197b7 --- a/examples/inference/llama_mistral/run_text_generation_llama3.sh +++ b/examples/inference/llama_mistral/run_text_generation_llama3.sh @@ -1,55 +1,55 @@ -#!/bin/bash -# This example will start serving the Llama3-8B model -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NVTE_APPLY_QK_LAYER_SCALING=0 - -DISTRIBUTED_ARGS="--nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr 0.0.0.0 \ - --master_port 6000" - -# Ensure CHECKPOINT and TOKENIZER_MODEL are provided -if [ -z "$1" ] || [ -z "$2" ]; then - echo "Error: You must provide 
CHECKPOINT and TOKENIZER_MODEL as command-line arguments." - echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" - exit 1 -fi - -# Assign command-line arguments to variables -CHECKPOINT=$1 -TOKENIZER_MODEL=$2 - -pip install flask-restful - -torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ - --use-checkpoint-args \ - --disable-bias-linear \ - --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model ${TOKENIZER_MODEL} \ - --transformer-impl transformer_engine \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --attention-softmax-in-fp32 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --untie-embeddings-and-output-weights \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 500000 \ - --use-rotary-position-embeddings \ - --swiglu \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --ffn-hidden-size 14336 \ - --load ${CHECKPOINT} \ - --num-attention-heads 32 \ - --max-position-embeddings 8192 \ - --bf16 \ - --micro-batch-size 1 \ - --seq-length 8192 +#!/bin/bash +# This example will start serving the Llama3-8B model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." + echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --use-checkpoint-args \ + --disable-bias-linear \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 500000 \ + --use-rotary-position-embeddings \ + --swiglu \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 8192 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 8192 diff --git a/examples/inference/llama_mistral/run_text_generation_mistral.sh b/examples/inference/llama_mistral/run_text_generation_mistral.sh old mode 100644 new mode 100755 index 4358fd494c7029b94d2f898f6618c0bc24c78c81..050de7993199a7ba3647a47b1dbe8b5597026b13 --- a/examples/inference/llama_mistral/run_text_generation_mistral.sh +++ b/examples/inference/llama_mistral/run_text_generation_mistral.sh @@ -1,53 +1,53 @@ -#!/bin/bash -# This example will start serving the Mistral-7B-v0.3 model -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -DISTRIBUTED_ARGS="--nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr 0.0.0.0 \ - --master_port 6000" - -# Ensure CHECKPOINT and TOKENIZER_MODEL are provided -if [ -z "$1" ] || [ -z "$2" ]; then - echo "Error: You must provide CHECKPOINT and 
TOKENIZER_MODEL as command-line arguments." - echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" - exit 1 -fi - -# Assign command-line arguments to variables -CHECKPOINT=$1 -TOKENIZER_MODEL=$2 - -pip install flask-restful - -torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ - --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model ${TOKENIZER_MODEL} \ - --use-checkpoint-args \ - --apply-layernorm-1p \ - --transformer-impl transformer_engine \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --use-flash-attn \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --ffn-hidden-size 14336 \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --load ${CHECKPOINT} \ - --num-attention-heads 32 \ - --max-position-embeddings 4096 \ - --bf16 \ - --micro-batch-size 1 \ - --seq-length 4096 \ - --seed 101 +#!/bin/bash +# This example will start serving the Mistral-7B-v0.3 model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." + echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --use-checkpoint-args \ + --apply-layernorm-1p \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --ffn-hidden-size 14336 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 4096 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 4096 \ + --seed 101 diff --git a/examples/inference/run_text_generation_server_345M.sh b/examples/inference/run_text_generation_server_345M.sh old mode 100644 new mode 100755 index e8e61adb163924f8ba9eed4a653d47fe9b0ee43a..2394710ea9f962d91fa11b441d46bc342b67b95f --- a/examples/inference/run_text_generation_server_345M.sh +++ b/examples/inference/run_text_generation_server_345M.sh @@ -1,31 +1,31 @@ -#!/bin/bash -# This example will start serving the 345M model. 
-DISTRIBUTED_ARGS="--nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -CHECKPOINT= -VOCAB_FILE= -MERGE_FILE= - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -pip install flask-restful - -torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --num-layers 24 \ - --hidden-size 1024 \ - --load ${CHECKPOINT} \ - --num-attention-heads 16 \ - --max-position-embeddings 1024 \ - --tokenizer-type GPT2BPETokenizer \ - --fp16 \ - --micro-batch-size 1 \ - --seq-length 1024 \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --seed 42 +#!/bin/bash +# This example will start serving the 345M model. +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +CHECKPOINT= +VOCAB_FILE= +MERGE_FILE= + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 24 \ + --hidden-size 1024 \ + --load ${CHECKPOINT} \ + --num-attention-heads 16 \ + --max-position-embeddings 1024 \ + --tokenizer-type GPT2BPETokenizer \ + --fp16 \ + --micro-batch-size 1 \ + --seq-length 1024 \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --seed 42 diff --git a/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh b/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh old mode 100644 new mode 100755 index 368cec3b312f05807ac9b050895bd832fe2ecb4f..8ca0c4194bc9bfa4b8e292180eb6cd0961ccc8c6 --- a/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh +++ b/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh @@ -1,29 +1,29 @@ -#!/bin/bash -# This example will start serving the 345M model that is partitioned 8 way tensor parallel -DISTRIBUTED_ARGS="--nproc_per_node 8 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -CHECKPOINT= -VOCAB_FILE= -MERGE_FILE= - -pip install flask-restful - -python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ - --tensor-model-parallel-size 8 \ - --pipeline-model-parallel-size 1 \ - --num-layers 24 \ - --hidden-size 1024 \ - --load ${CHECKPOINT} \ - --num-attention-heads 16 \ - --max-position-embeddings 1024 \ - --tokenizer-type GPT2BPETokenizer \ - --fp16 \ - --micro-batch-size 1 \ - --seq-length 1024 \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --seed 42 +#!/bin/bash +# This example will start serving the 345M model that is partitioned 8 way tensor parallel +DISTRIBUTED_ARGS="--nproc_per_node 8 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +CHECKPOINT= +VOCAB_FILE= +MERGE_FILE= + +pip install flask-restful + +python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --num-layers 24 \ + --hidden-size 1024 \ + --load ${CHECKPOINT} \ + --num-attention-heads 16 \ + --max-position-embeddings 1024 \ + --tokenizer-type GPT2BPETokenizer \ + --fp16 \ + --micro-batch-size 1 \ + --seq-length 1024 \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --seed 42 diff --git a/examples/mamba/run_text_gen_server_8b.sh b/examples/mamba/run_text_gen_server_8b.sh old mode 100644 new mode 100755 index 
8d3137f24429a0a1bb1ec3cc6febc03804c20455..5c712ffae701890e688a6bff2fed8479ed25a5fd --- a/examples/mamba/run_text_gen_server_8b.sh +++ b/examples/mamba/run_text_gen_server_8b.sh @@ -1,50 +1,50 @@ -#!/bin/bash - -# Use: ./run_text_gen_server_8b.sh -# To launch the client: python ../../tools/text_generation_cli.py - -CHECKPOINT_PATH=$1 -TOKENIZER_PATH=$2 - -DISTRIBUTED_ARGS="--nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_IB_TIMEOUT=19 -export NCCL_IB_QPS_PER_CONNECTION=4 - -export TRITON_CACHE_DIR="./triton-cache/" -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" - -torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --untie-embeddings-and-output-weights \ - --num-layers 56 \ - --hidden-size 4096 \ - --load ${CHECKPOINT_PATH} \ - --num-attention-heads 32 \ - --group-query-attention \ - --num-query-groups 8 \ - --hybrid-attention-ratio 0.08 \ - --hybrid-mlp-ratio 0.5 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --disable-bias-linear \ - --normalization RMSNorm \ - --seq-length 4096 \ - --max-position-embeddings 4096 \ - --position-embedding-type none \ - --tokenizer-type GPTSentencePieceTokenizer \ - --tokenizer-model ${TOKENIZER_PATH} \ - --distributed-backend nccl \ - --distributed-timeout-minutes 1440 \ - --bf16 \ - --micro-batch-size 1 \ - --use-mcore-models \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ - --seed 42 +#!/bin/bash + +# Use: ./run_text_gen_server_8b.sh +# To launch the client: python ../../tools/text_generation_cli.py + +CHECKPOINT_PATH=$1 +TOKENIZER_PATH=$2 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +export TRITON_CACHE_DIR="./triton-cache/" +export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" + +torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --untie-embeddings-and-output-weights \ + --num-layers 56 \ + --hidden-size 4096 \ + --load ${CHECKPOINT_PATH} \ + --num-attention-heads 32 \ + --group-query-attention \ + --num-query-groups 8 \ + --hybrid-attention-ratio 0.08 \ + --hybrid-mlp-ratio 0.5 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --position-embedding-type none \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --distributed-timeout-minutes 1440 \ + --bf16 \ + --micro-batch-size 1 \ + --use-mcore-models \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --seed 42 diff --git a/examples/mamba/run_text_gen_server_8b_gpt3.sh b/examples/mamba/run_text_gen_server_8b_gpt3.sh old mode 100644 new mode 100755 diff --git a/examples/mamba/train.sh b/examples/mamba/train.sh old mode 100644 new mode 100755 index 3952a997d479e732d6e08545402103286518d97f..033a000351aade51fa7c303cb43675d941ada469 --- a/examples/mamba/train.sh +++ b/examples/mamba/train.sh @@ -1,105 +1,105 @@ -#!/bin/bash - -# Use: ./train.sh - 
-MODEL_SCALE="800M" # or "8B" - -case "${MODEL_SCALE}" in - "800M") - TENSOR_MODEL_PARALLEL_SIZE=1 - NUM_LAYERS=48 - HIDDEN_SIZE=1024 - NUM_ATTENTION_HEADS=16 - GLOBAL_BATCH_SIZE=32 - ;; - "8B") - TENSOR_MODEL_PARALLEL_SIZE=4 - NUM_LAYERS=56 - HIDDEN_SIZE=4096 - NUM_ATTENTION_HEADS=32 - GLOBAL_BATCH_SIZE=8 - ;; - *) - echo "Invalid version specified" - exit 1 - ;; -esac - -DATA_PATH=$1 -TOKENIZER_PATH=$2 - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_IB_TIMEOUT=19 -export NCCL_IB_QPS_PER_CONNECTION=4 - -CHECKPOINT_DIR="./checkpoints" -DATACACHE_DIR="./data-cache" -TENSORBOARD_DIR="./tensorboard" - -mkdir -p ${CHECKPOINT_DIR} -mkdir -p ${DATACACHE_DIR} -mkdir -p ${TENSORBOARD_DIR} - -export TRITON_CACHE_DIR="./triton-cache/" -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" - -SEQ_LEN=4096 -TRAIN_SAMPLES=73242188 # 300B tokens / 4096 -LR_WARMUP_SAMPLES=50000 -LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES - -options=" \ - --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \ - --sequence-parallel \ - --pipeline-model-parallel-size 1 \ - --use-distributed-optimizer \ - --overlap-param-gather \ - --overlap-grad-reduce \ - --untie-embeddings-and-output-weights \ - --init-method-std 0.02 \ - --position-embedding-type none \ - --num-layers ${NUM_LAYERS} \ - --hidden-size ${HIDDEN_SIZE} \ - --num-attention-heads ${NUM_ATTENTION_HEADS} \ - --group-query-attention \ - --num-query-groups 8 \ - --hybrid-attention-ratio 0.08 \ - --hybrid-mlp-ratio 0.5 \ - --seq-length ${SEQ_LEN} \ - --max-position-embeddings ${SEQ_LEN} \ - --train-samples ${TRAIN_SAMPLES} \ - --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ - --lr-decay-samples ${LR_DECAY_SAMPLES} \ - --save ${CHECKPOINT_DIR} \ - --load ${CHECKPOINT_DIR} \ - --data-path ${DATA_PATH} \ - --data-cache-path ${DATACACHE_DIR} \ - --split 99,1,0 \ - --tokenizer-type GPTSentencePieceTokenizer \ - --tokenizer-model ${TOKENIZER_PATH} \ - --distributed-backend nccl \ - --micro-batch-size 4 \ - --global-batch-size ${GLOBAL_BATCH_SIZE} \ - --lr 2.5e-4 \ - --min-lr 2.5e-5 \ - --lr-decay-style cosine \ - --weight-decay 0.1 \ - --clip-grad 1.0 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --disable-bias-linear \ - --normalization RMSNorm \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --log-interval 10 \ - --save-interval 2000 \ - --eval-interval 2000 \ - --eval-iters 32 \ - --bf16 \ - --use-mcore-models \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ - --no-create-attention-mask-in-dataloader \ - --tensorboard-dir ${TENSORBOARD_DIR}" - -torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options} +#!/bin/bash + +# Use: ./train.sh + +MODEL_SCALE="800M" # or "8B" + +case "${MODEL_SCALE}" in + "800M") + TENSOR_MODEL_PARALLEL_SIZE=1 + NUM_LAYERS=48 + HIDDEN_SIZE=1024 + NUM_ATTENTION_HEADS=16 + GLOBAL_BATCH_SIZE=32 + ;; + "8B") + TENSOR_MODEL_PARALLEL_SIZE=4 + NUM_LAYERS=56 + HIDDEN_SIZE=4096 + NUM_ATTENTION_HEADS=32 + GLOBAL_BATCH_SIZE=8 + ;; + *) + echo "Invalid version specified" + exit 1 + ;; +esac + +DATA_PATH=$1 +TOKENIZER_PATH=$2 + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +CHECKPOINT_DIR="./checkpoints" +DATACACHE_DIR="./data-cache" +TENSORBOARD_DIR="./tensorboard" + +mkdir -p ${CHECKPOINT_DIR} +mkdir -p ${DATACACHE_DIR} +mkdir -p ${TENSORBOARD_DIR} + +export TRITON_CACHE_DIR="./triton-cache/" +export 
TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" + +SEQ_LEN=4096 +TRAIN_SAMPLES=73242188 # 300B tokens / 4096 +LR_WARMUP_SAMPLES=50000 +LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES + +options=" \ + --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \ + --sequence-parallel \ + --pipeline-model-parallel-size 1 \ + --use-distributed-optimizer \ + --overlap-param-gather \ + --overlap-grad-reduce \ + --untie-embeddings-and-output-weights \ + --init-method-std 0.02 \ + --position-embedding-type none \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTENTION_HEADS} \ + --group-query-attention \ + --num-query-groups 8 \ + --hybrid-attention-ratio 0.08 \ + --hybrid-mlp-ratio 0.5 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples ${TRAIN_SAMPLES} \ + --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ + --lr-decay-samples ${LR_DECAY_SAMPLES} \ + --save ${CHECKPOINT_DIR} \ + --load ${CHECKPOINT_DIR} \ + --data-path ${DATA_PATH} \ + --data-cache-path ${DATACACHE_DIR} \ + --split 99,1,0 \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --micro-batch-size 4 \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr 2.5e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --log-interval 10 \ + --save-interval 2000 \ + --eval-interval 2000 \ + --eval-iters 32 \ + --bf16 \ + --use-mcore-models \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --no-create-attention-mask-in-dataloader \ + --tensorboard-dir ${TENSORBOARD_DIR}" + +torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options} diff --git a/examples/mixtral/hostfile_mixtral_8x7B b/examples/mixtral/hostfile_mixtral_8x7B new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/run_mixtral8x7B_1nodes.sh b/examples/mixtral/run_mixtral_8x7B_1nodes.sh old mode 100644 new mode 100755 similarity index 100% rename from run_mixtral8x7B_1nodes.sh rename to examples/mixtral/run_mixtral_8x7B_1nodes.sh diff --git a/run_GPT-MOE.sh b/examples/mixtral/run_mixtral_8x7B_multinodes.sh old mode 100644 new mode 100755 similarity index 61% rename from run_GPT-MOE.sh rename to examples/mixtral/run_mixtral_8x7B_multinodes.sh index 9c16e0a4949cff513741eaac2e808d128869a59b..0fa5dedb107025b8afecab5c9fda43600b07b617 --- a/run_GPT-MOE.sh +++ b/examples/mixtral/run_mixtral_8x7B_multinodes.sh @@ -7,13 +7,13 @@ do fi done -mpirun -np 256 --hostfile gptnodes \ +mpirun -np 32 --hostfile hostfile_mixtral_8x7B \ --allow-run-as-root \ --bind-to none \ --mca plm_rsh_no_tree_spawn 1 \ - train_GPT-MOE_567B.sh node002 --profiling=$profiling > output.log 2>&1 + train_mixtral_8x7B_multinodes.sh node066 --profiling=$profiling > output.log 2>&1 wait rm -rf CKPT -#rm -rf mixtral_dataset/my-mixtral_text_document \ No newline at end of file +#rm -rf mixtral_dataset/my-mixtral_text_document diff --git a/train_mixtral_8x7B_1nodes.sh b/examples/mixtral/train_mixtral_8x7B_1nodes.sh similarity index 87% rename from train_mixtral_8x7B_1nodes.sh rename to examples/mixtral/train_mixtral_8x7B_1nodes.sh index 3420481bb4b572a5239ce2facbe742bb95ffbf4f..6e70fb1ad5e82b1e42924f8048d8ee58fcbe5b6c 100755 --- 
a/train_mixtral_8x7B_1nodes.sh +++ b/examples/mixtral/train_mixtral_8x7B_1nodes.sh @@ -4,18 +4,23 @@ for para in $* do if [[ $para == --profiling* ]];then profiling=${para#*=} - export GPU_FLUSH_ON_EXECUTION=1 - export HIP_DIRECT_DISPATCH=0 fi done -source /opt/dtk/env.sh # Runs Mixtral 8x7B model +source /opt/dtk/env.sh + +# default env +CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )" +MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR})) +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH +export GLOG_minloglevel=3 export CUDA_DEVICE_MAX_CONNECTIONS=1 export HSA_FORCE_FINE_GRAIN_PCIE=1 export OMP_NUM_THREADS=1 export GPU_MAX_HW_QUEUES=10 +# nccl env export NCCL_ALGO=Ring export NCCL_MIN_NCHANNELS=32 export NCCL_MAX_NCHANNELS=32 @@ -23,9 +28,10 @@ export NCCL_NET_GDR_LEVEL=7 export NCCL_NET_GDR_READ=1 export RCCL_SDMA_COPY_ENABLE=0 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 -#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" +export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" + +# enable BatchLinear export GROUPED_GEMM_BatchLinear=1 -export GLOG_minloglevel=3 RANK=$OMPI_COMM_WORLD_RANK LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK @@ -75,7 +81,7 @@ MOE_ARGS=( --moe-token-dispatcher-type alltoall --moe-expert-capacity-factor 0.5 --moe-pad-expert-input-to-capacity - --moe-grouped-gemm + #--moe-grouped-gemm ) DATA_ARGS=( @@ -103,25 +109,17 @@ TRAINING_ARGS=( TORCH_PROFIE_ARGS=( --profile - --profile-ranks 0 1 2 3 4 5 6 7 8 + --profile-ranks 0 1 2 3 4 5 6 7 --profile-step-start 3 --profile-step-end 4 - --profile-dir torch_prof_mixtral_1nodes + --profile-dir torch_prof_mixtral_1nodes_tp2-pp1-ep8-ep_tp1 --use-pytorch-profiler ) -HIP_PROFIE_ARGS=( - --profile - --profile-ranks 0 1 2 3 4 5 6 7 8 - --profile-step-start 4 - --profile-step-end 5 - --use-hip-profiler -) - MODEL_PARALLEL_ARGS=( --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 - --expert-model-parallel-size 2 + --expert-model-parallel-size 8 --expert-tensor-parallel-size 1 --use-distributed-optimizer --sequence-parallel @@ -159,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \ if [[ $profiling == "torch" ]]; then APP+=" ${TORCH_PROFIE_ARGS[@]}" -elif [[ $profiling == "hip" ]]; then - mkdir -p hip_prof_data - APP+=" ${HIP_PROFIE_ARGS[@]}" - APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}" fi #for hygon cpu diff --git a/train_mixtral_8x7B_2nodes.sh b/examples/mixtral/train_mixtral_8x7B_multinodes.sh old mode 100644 new mode 100755 similarity index 85% rename from train_mixtral_8x7B_2nodes.sh rename to examples/mixtral/train_mixtral_8x7B_multinodes.sh index d0b393391db7e1bdda9e36d8b95e23605bc64a5e..6413c9080fc5873c1fc65da45af833f6af246608 --- a/train_mixtral_8x7B_2nodes.sh +++ b/examples/mixtral/train_mixtral_8x7B_multinodes.sh @@ -4,18 +4,23 @@ for para in $* do if [[ $para == --profiling* ]];then profiling=${para#*=} - export GPU_FLUSH_ON_EXECUTION=1 - export HIP_DIRECT_DISPATCH=0 fi done -source /opt/dtk/env.sh # Runs Mixtral 8x7B model +source /opt/dtk/env.sh + +# default env +CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )" +MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR})) +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH +export GLOG_minloglevel=3 export CUDA_DEVICE_MAX_CONNECTIONS=1 export HSA_FORCE_FINE_GRAIN_PCIE=1 export OMP_NUM_THREADS=1 export GPU_MAX_HW_QUEUES=10 +# nccl env export NCCL_ALGO=Ring export NCCL_MIN_NCHANNELS=32 export NCCL_MAX_NCHANNELS=32 @@ -23,9 +28,10 @@ export
NCCL_NET_GDR_LEVEL=7 export NCCL_NET_GDR_READ=1 export RCCL_SDMA_COPY_ENABLE=0 export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 -#export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" +export NCCL_TOPO_FILE="/public/home/xingjl/dependency/rccl-tests-0204/topo-input.xml" + +# enable BatchLinear export GROUPED_GEMM_BatchLinear=1 -export GLOG_minloglevel=3 RANK=$OMPI_COMM_WORLD_RANK LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK @@ -99,9 +105,6 @@ TRAINING_ARGS=( --bf16 --overlap-param-gather --overlap-grad-reduce - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 ) TORCH_PROFIE_ARGS=( @@ -109,23 +112,15 @@ TORCH_PROFIE_ARGS=( --profile-ranks 0 1 2 3 8 9 10 11 --profile-step-start 3 --profile-step-end 4 - --profile-dir torch_prof_data_mixtral_2nodes + --profile-dir torch_prof_mixtral_4nodes_tp2-pp8-ep2-ep_tp1 --use-pytorch-profiler ) -HIP_PROFIE_ARGS=( - --profile - --profile-ranks 0 1 2 3 8 9 10 11 - --profile-step-start 4 - --profile-step-end 5 - --use-hip-profiler -) - MODEL_PARALLEL_ARGS=( - --tensor-model-parallel-size 4 - --pipeline-model-parallel-size 4 + --tensor-model-parallel-size 2 + --pipeline-model-parallel-size 8 --expert-model-parallel-size 2 - --expert-tensor-parallel-size 2 + --expert-tensor-parallel-size 1 --use-distributed-optimizer --sequence-parallel ) @@ -162,10 +157,6 @@ APP="python3 -u pretrain_gpt.py \ if [[ $profiling == "torch" ]]; then APP+=" ${TORCH_PROFIE_ARGS[@]}" -elif [[ $profiling == "hip" ]]; then - mkdir -p hip_prof_data - APP+=" ${HIP_PROFIE_ARGS[@]}" - APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}" fi #for hygon cpu diff --git a/examples/mixtral/train_mixtral_8x7b_distributed.sh b/examples/mixtral/train_mixtral_8x7b_distributed.sh deleted file mode 100644 index ed44d60f5c0363eb271ab7fbb02636cc6c82c0a9..0000000000000000000000000000000000000000 --- a/examples/mixtral/train_mixtral_8x7b_distributed.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/bin/bash - -# Runs Mixtral 8x7B model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=${MASTER_ADDR:-"localhost"} -MASTER_PORT=${MASTER_PORT:-"6000"} -NNODES=${SLURM_NNODES:-"1"} -NODE_RANK=${RANK:-"0"} -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH=$1 -TOKENIZER_MODEL=$2 -DATA_PATH=$3 - -DISTRIBUTED_ARGS=( - --nproc_per_node $GPUS_PER_NODE - --nnodes $NNODES - --node_rank $NODE_RANK - --master_addr $MASTER_ADDR - --master_port $MASTER_PORT -) - -MODEL_ARGS=( - --use-mcore-models - --disable-bias-linear - --seq-length 4096 - --max-position-embeddings 32768 - --num-layers 32 - --hidden-size 4096 - --ffn-hidden-size 14336 - --num-attention-heads 32 - --init-method-std 0.01 - --attention-dropout 0.0 - --hidden-dropout 0.0 - --normalization RMSNorm - --position-embedding-type rope - --swiglu - --untie-embeddings-and-output-weights - --group-query-attention - --num-query-groups 8 - --no-masked-softmax-fusion - --no-position-embedding - --rotary-base 1000000 -) - -MOE_ARGS=( - --num-experts 8 - --moe-router-topk 2 - --moe-router-load-balancing-type aux_loss - --moe-aux-loss-coeff 1e-2 - --moe-grouped-gemm - --moe-token-dispatcher-type alltoall - --overlap-param-gather - --overlap-grad-reduce -) - -DATA_ARGS=( - --tokenizer-type Llama2Tokenizer - --tokenizer-model ${TOKENIZER_MODEL} - --data-path $DATA_PATH - --split 99990,8,2 -) - -TRAINING_ARGS=( - --micro-batch-size 1 - --global-batch-size 256 - --lr 1e-4 - --train-iters 500000 - --lr-decay-iters 
320000 - --lr-decay-style cosine - --min-lr 1.0e-5 - --weight-decay 0.1 - --lr-warmup-iters 500 - --clip-grad 1.0 - --bf16 -) - -MODEL_PARALLEL_ARGS=( - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 4 - --expert-model-parallel-size 8 - --use-distributed-optimizer - --sequence-parallel -) - -LOGGING_ARGS=( - --log-interval 1 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ - --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \ - --no-load-optim \ - --no-load-rng -) - -if [ -n "${WANDB_API_KEY}" ]; then - LOGGING_ARGS+=( - --wandb-project ${WANDB_PROJECT:-"Mixtral"} - --wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"} - ) -fi - - -torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \ - ${MODEL_ARGS[@]} \ - ${MOE_ARGS[@]} \ - ${DATA_ARGS[@]} \ - ${TRAINING_ARGS[@]} \ - ${MODEL_PARALLEL_ARGS[@]} \ - ${LOGGING_ARGS[@]} diff --git a/examples/multimodal/combine_lm_vision_checkpoints.sh b/examples/multimodal/combine_lm_vision_checkpoints.sh index 52de16ecd2337ea19502cf456f88992310618bb3..b6e327705318aab84575d4ec3aec463d14ad060d 100644 --- a/examples/multimodal/combine_lm_vision_checkpoints.sh +++ b/examples/multimodal/combine_lm_vision_checkpoints.sh @@ -1,57 +1,57 @@ -#/bin/bash -MCORE_LM=$1 # -MCORE_VISION=$2 # -OUTPUT_DIR=$3 # -MODEL_TYPE=$4 # Model type. Default: Mistral CLIP example. - -if [[ $MODEL_TYPE == "nvlm" ]]; then - # NVLM TP=8 - python examples/multimodal/combine_state_dicts.py \ - --input \ - ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_04/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_04/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_05/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_05/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_06/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_06/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_07/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_07/model_optim_rng.pt \ - --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ - --output \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_04/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_05/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_06/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_07/model_optim_rng.pt -else - # Mistral CLIP example TP=4. 
- python examples/multimodal/combine_state_dicts.py \ - --input \ - ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ - --output \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt -fi - -echo 1 > ${OUTPUT_DIR}/latest_checkpointed_iteration.txt +#/bin/bash +MCORE_LM=$1 # +MCORE_VISION=$2 # +OUTPUT_DIR=$3 # +MODEL_TYPE=$4 # Model type. Default: Mistral CLIP example. + +if [[ $MODEL_TYPE == "nvlm" ]]; then + # NVLM TP=8 + python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_07/model_optim_rng.pt +else + # Mistral CLIP example TP=4. 
+ python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt +fi + +echo 1 > ${OUTPUT_DIR}/latest_checkpointed_iteration.txt diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index ee404604b650d32f4535a53dfba24498d9ab4f77..2bee6715220ca30a8068b98de85c5d7f0a350b2d 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -1,200 +1,280 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from dataclasses import dataclass - -import torch - -from megatron.training.activations import fast_gelu, quick_gelu, squared_relu - - -def get_language_model_config(config): - if config.language_model_type == "llama3_8b": - config.activation_func = torch.nn.functional.silu - config.add_bias_linear = False - config.bias_activation_fusion = False - config.gated_linear_unit = True - config.apply_query_key_layer_scaling = False - config.layernorm_zero_centered_gamma = ( - False # Zero centered gamma not supported for RMSNorm - ) - config.bias_dropout_fusion = False - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 14336 - elif config.language_model_type == "mistral_7b": - config.activation_func = torch.nn.functional.silu - config.add_bias_linear = False - config.bias_activation_fusion = False - config.gated_linear_unit = True - config.apply_query_key_layer_scaling = False - config.layernorm_zero_centered_gamma = ( - False # Zero centered gamma not supported for RMSNorm - ) - config.bias_dropout_fusion = False - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 14336 - elif config.language_model_type == "yi-34b": - config.activation_func = torch.nn.functional.silu - config.add_bias_linear = False - config.bias_activation_fusion = False - config.gated_linear_unit = True - config.apply_query_key_layer_scaling = False - config.layernorm_zero_centered_gamma = ( - False # Zero centered gamma not supported for RMSNorm - ) - config.bias_dropout_fusion = False - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 20480 - elif config.language_model_type == "qwen2.5_7B": - config.activation_func = torch.nn.functional.silu - config.add_bias_linear = False - config.add_qkv_bias = True - config.bias_activation_fusion = False - config.gated_linear_unit = True - config.apply_query_key_layer_scaling = False - config.layernorm_zero_centered_gamma = ( - False # Zero centered gamma not supported for RMSNorm - ) - config.bias_dropout_fusion = False - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 18944 - 
elif config.language_model_type == "qwen2.0_72B": - config.activation_func = torch.nn.functional.silu - config.add_bias_linear = False - config.add_qkv_bias = True - config.bias_activation_fusion = False - config.gated_linear_unit = True - config.apply_query_key_layer_scaling = False - config.layernorm_zero_centered_gamma = ( - False # Zero centered gamma not supported for RMSNorm - ) - config.bias_dropout_fusion = False - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 29568 - else: - raise ValueError(f"unknown language model type {config.language_model_type}") - - return config - - -def get_vision_model_config(config, apply_query_key_layer_scaling): - if config.vision_model_type == "clip": - config.num_layers = 24 - config.num_attention_heads = 16 - config.add_bias_linear = True - config.add_qkv_bias = True - config.hidden_size = 1024 - config.hidden_dropout = 0.0 - config.attention_dropout = 0.0 - config.ffn_hidden_size = 4096 - config.gated_linear_unit = False - config.activation_func = quick_gelu - config.kv_channels = 64 - config.num_query_groups = 16 - config.layernorm_zero_centered_gamma = False - config.apply_query_key_layer_scaling = apply_query_key_layer_scaling - config.bias_activation_fusion = False - config.bias_dropout_fusion = False - config.attention_softmax_in_fp32 = True - config.normalization = 'LayerNorm' - config.apply_rope_fusion = False - elif config.vision_model_type == "siglip": - config.num_layers = 27 - config.num_attention_heads = 16 - config.add_bias_linear = True - config.add_qkv_bias = True - config.hidden_size = 1152 - config.hidden_dropout = 0.0 - config.attention_dropout = 0.0 - config.ffn_hidden_size = 4304 - config.gated_linear_unit = False - config.activation_func = fast_gelu - config.kv_channels = 72 - config.num_query_groups = 16 - config.layernorm_zero_centered_gamma = False - config.apply_query_key_layer_scaling = apply_query_key_layer_scaling - config.bias_activation_fusion = False - config.bias_dropout_fusion = False - config.attention_softmax_in_fp32 = True - config.normalization = 'LayerNorm' - config.apply_rope_fusion = False - config.qk_layernorm = False - config.layernorm_epsilon = 1e-6 - elif config.vision_model_type == "internvit": - config.num_layers = 45 - config.num_attention_heads = 32 # Padded for TP=8. - config.num_query_groups = 32 # Padded for TP=8. - config.kv_channels = 128 - config.add_bias_linear = True - config.add_qkv_bias = False - config.hidden_size = 3200 - config.hidden_dropout = 0.0 - config.attention_dropout = 0.0 - config.ffn_hidden_size = 12800 - config.gated_linear_unit = False - config.activation_func = torch.nn.functional.gelu - config.layernorm_zero_centered_gamma = False - config.apply_query_key_layer_scaling = apply_query_key_layer_scaling - config.bias_activation_fusion = False - config.bias_dropout_fusion = False - config.attention_softmax_in_fp32 = True - config.normalization = 'RMSNorm' - config.layernorm_epsilon = 1e-6 - config.apply_rope_fusion = False - else: - raise ValueError(f"unknown vision model type {config.vision_model_type}") - - return config - - -def get_vision_projection_config(config, hidden_size): - config.gated_linear_unit = False - config.bias_activation_fusion = False - config.add_bias_linear = False - config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model. 
- if config.language_model_type == "llama3_8b": - config.ffn_hidden_size = 14336 - config.activation_func = torch.nn.functional.gelu - elif config.language_model_type == "mistral_7b": - config.ffn_hidden_size = 14336 - config.activation_func = torch.nn.functional.gelu - config.normalization = None - elif config.language_model_type == "yi-34b": - config.ffn_hidden_size = 20480 - config.normalization = "LayerNorm" - config.activation_func = torch.nn.functional.gelu - elif config.language_model_type == "qwen2.5_7B": - config.ffn_hidden_size = 3584 - config.activation_func = torch.nn.functional.gelu - elif config.language_model_type == "qwen2.0_72B": - config.ffn_hidden_size = 29568 - config.normalization = "LayerNorm" - config.activation_func = torch.nn.functional.gelu - else: - raise ValueError(f"unknown language model type {config.language_model_type}") - - return config - - -@dataclass -class EvaluationConfig: - """Evaluation related configuration.""" - task: str - - temperature: float = 1.0 - top_p: float = 0.0 - top_k: int = 0 - - out_seq_length: int = 32 - - output_path: str = "" - - input_image_path: str = "" - gt_path: str = "" - - num_partitions: int = 1 - partition_id: int = 0 - num_samples_per_partition: int = 0 +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + +import torch + +from megatron.training.activations import fast_gelu, quick_gelu, squared_relu + + +def get_language_model_config(config): + if config.language_model_type == "llama3_8b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "llama3.1_8b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "llama3.1_70B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 28672 + elif config.language_model_type == "mistral_7b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "yi-34b": 
+ config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 20480 + elif config.language_model_type == "qwen2.5_7B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 18944 + elif config.language_model_type == "qwen2.0_72B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 29568 + elif config.language_model_type == "llama3.2_1b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 8192 + elif config.language_model_type.startswith("huggingface"): + # Loaded from HuggingFace config file. 
+ pass + else: + raise ValueError(f"unknown language model type {config.language_model_type}") + + return config + + +def get_vision_model_config(config, apply_query_key_layer_scaling): + if config.vision_model_type == "clip": + config.num_layers = 24 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1024 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 4096 + config.gated_linear_unit = False + config.activation_func = quick_gelu + config.kv_channels = 64 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + elif config.vision_model_type == "siglip": + config.num_layers = 27 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1152 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 4304 + config.gated_linear_unit = False + config.activation_func = fast_gelu + config.kv_channels = 72 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + config.qk_layernorm = False + config.layernorm_epsilon = 1e-6 + elif config.vision_model_type == "internvit": + config.num_layers = 45 + config.num_attention_heads = ((24 // config.tensor_model_parallel_size) + 1) * config.tensor_model_parallel_size + config.num_query_groups = config.num_attention_heads + config.add_bias_linear = True + config.add_qkv_bias = False + config.hidden_size = 3200 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 12800 + config.gated_linear_unit = False + config.activation_func = torch.nn.functional.gelu + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'RMSNorm' + config.layernorm_epsilon = 1e-6 + config.apply_rope_fusion = False + elif config.vision_model_type == "radio": + config.num_layers = 32 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1280 + config.ffn_hidden_size = 5120 + config.gated_linear_unit = False + config.activation_func = fast_gelu + config.kv_channels = 80 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + config.qk_layernorm = False + config.layernorm_epsilon = 1e-6 + elif config.vision_model_type.startswith("huggingface"): + # Loaded from HuggingFace config file. 
+ pass + else: + raise ValueError(f"unknown vision model type {config.vision_model_type}") + + return config + + +def get_vision_projection_config(config, hidden_size): + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = False + config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model. + if config.language_model_type == "llama3_8b": + config.ffn_hidden_size = 14336 + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "llama3.1_8b": + config.ffn_hidden_size = 4096 + config.activation_func = torch.nn.functional.gelu + config.layernorm_epsilon = 1e-5 + config.add_bias_linear = True + config.normalization = "LayerNorm" + elif config.language_model_type == "mistral_7b": + config.ffn_hidden_size = 14336 + config.activation_func = torch.nn.functional.gelu + config.normalization = None + elif config.language_model_type == "yi-34b": + config.ffn_hidden_size = 20480 + config.normalization = "LayerNorm" + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.5_7B": + config.ffn_hidden_size = 3584 + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.0_72B": + config.ffn_hidden_size = 29568 + config.normalization = "LayerNorm" + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "llama3.2_1b": + config.ffn_hidden_size = 2048 + config.activation_func = torch.nn.functional.gelu + config.normalization = "LayerNorm" + elif config.language_model_type.startswith("huggingface"): + config.activation_func = torch.nn.functional.gelu + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path) + if "qwen" in hf_config.model_type: + config.ffn_hidden_size = 1536 + else: + raise ValueError(f"unknown language model type {config.language_model_type}") + + return config + + +@dataclass +class EvaluationConfig: + """Evaluation related configuration.""" + task: str + + temperature: float = 1.0 + top_p: float = 0.0 + top_k: int = 0 + + out_seq_length: int = 32 + + output_path: str = "" + + input_image_path: str = "" + gt_path: str = "" + + num_partitions: int = 1 + partition_id: int = 0 + num_samples_per_partition: int = 0 diff --git a/examples/multimodal/dataset_helpers.py b/examples/multimodal/dataset_helpers.py index ecbbc502c08bcda12d52c74eaabdbd3ffc3d774b..bd1be085ee0f4ff789f6d4281f34f03ae9707b46 100644 --- a/examples/multimodal/dataset_helpers.py +++ b/examples/multimodal/dataset_helpers.py @@ -1,814 +1,906 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
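The `get_language_model_config`, `get_vision_model_config`, and `get_vision_projection_config` helpers in the hunk above all follow one pattern: take a partially populated config object and override architecture-specific fields based on a model-type string. A minimal sketch of that pattern, using a plain namespace instead of Megatron's real config class (the field values below mirror the `llama3_8b` branch above; the object type is an assumption for illustration):

```python
from types import SimpleNamespace

import torch

def get_language_model_config_sketch(config):
    # Override architecture fields based on the model type, mirroring the pattern above.
    if config.language_model_type == "llama3_8b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.gated_linear_unit = True
        config.ffn_hidden_size = 14336
    else:
        raise ValueError(f"unknown language model type {config.language_model_type}")
    return config

cfg = SimpleNamespace(language_model_type="llama3_8b")
cfg = get_language_model_config_sketch(cfg)
print(cfg.ffn_hidden_size)  # 14336
```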
-import bisect -import dataclasses -import json -import re -import sys -import traceback -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union - -from image_processing import get_visual_transform -from PIL import Image -from torchvision.transforms import ToPILImage -import numpy as np -import torch - -from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN -from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings -from megatron.energon import ( - Batch, - CaptioningSample, - DefaultTaskEncoder, - OCRSample, - Sample, - SimilarityInterleavedSample, - VQASample, - MultiChoiceVQASample -) -from megatron.energon.task_encoder.base import stateless -from megatron.training import get_args, get_tokenizer - - -@dataclass -class ImageTaskSample(Sample): - __key__: str - __restore_key__: Tuple[Union[str, int, tuple], ...] - __subflavor__: Dict - __subflavors__: Dict - # (c, h, w) - imgs: List[torch.Tensor] - num_tiles: List[int] - tokens: torch.Tensor - total_len: int # Total token count in the sample, including text and image tokens - labels: torch.Tensor = None - - -@dataclass -class ImageTaskSamplePacked(Sample): - """Dataclass to store a single packed sample (not a batch). - - P = Number of sub-samples in the packed sample - seq_len = Total sequence length - num_imgs = Number of images across all samples in the packed sample - """ - - __key__: str # Sample name - __restore_key__: Tuple[Union[str, int, tuple], ...] - __subflavor__: Dict # Sample metadata. Deprecated. - __subflavors__: Dict # Sample metadata. - tokens: torch.Tensor # Input tokens packed into a single tensor (seq_len,) - labels: torch.Tensor # Target tokens packed into a single tensor (seq_len,) - imgs: List[torch.Tensor] # Input images - num_tiles: List[int] # Number of tiles for each image of each sample (num_imgs) - max_length: int # Maximum length across sub-samples. - cu_lengths: List[int] # Cumulative length of each sub-sample in this packed sample incl. text and image tokens (P,) - - -# Typing for the resulting batch data after encode_batch() -@dataclass -class ImageTaskBatchPacked(Batch): - """Dataclass to store a batch of packed samples. - - N = Batch size - P = Number of samples in the packed sample - seq_len = Maximum sequence length - num_imgs = Number of images across all samples in the packed sample - """ - - __key__: List[str] # Sample names - __restore_key__: Tuple[Union[str, int, tuple], ...] - __subflavor__: Dict # Sample metadata. Deprecated. - __subflavors__: List[Dict] # Sample metadatas. - tokens: torch.Tensor # Input tokens packed and padded (N, seq_len) - labels: torch.Tensor # Target tokens packed and padded (N, seq_len) - imgs: torch.Tensor # All image tiles stacked into a single tensor (num_tiles, C, H, W) - num_tiles: List[List[int]] # Number of tiles per image (N, num_imgs) - max_lengths: List[int] # Maximum length across sub-samples (N,) - cu_lengths: List[List[int]] # Cumulative length of each sub-sample in each packed sample of the batch (N, P) - - -# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L19 -# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. 
-def search_for_fit(numbers: List[int], capacity: int) -> int: - """Finds the index of largest number that fits into the knapsack with the given capacity.""" - index = bisect.bisect(numbers, capacity) - return -1 if index == 0 else (index - 1) - - -# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L27 -# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. -def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List: - """Greedy algorithm with binary search for the knapsack problem. - - Pack as many samples as possible given a maximum capacity and capacities of individual samples. - Used if sequence packing is enabled. - """ - assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length." - - knapsacks = [] - - if len(item_sizes) == 0: - return knapsacks - - # Sort sample lengths and samples together. - sorted_item_sizes, sorted_samples = zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0])) - sorted_item_sizes = list(sorted_item_sizes) - sorted_samples = list(sorted_samples) - - # Check if all samples fit in the knapsack capacity. - if sorted_item_sizes[-1] > max_capacity: - raise ValueError(f"knapsack: A sample is larger {sorted_item_sizes[-1]} than the max_sequence_length {max_capacity}.") - - while sorted_item_sizes: - current_knapsack = [] - remaining_capacity = max_capacity - - while True: - idx = search_for_fit(sorted_item_sizes, remaining_capacity) - if idx == -1: - break # Can't fit more samples. - - remaining_capacity -= sorted_item_sizes[idx] - - sorted_item_sizes.pop(idx) - sample = sorted_samples.pop(idx) - current_knapsack.append(sample) - - knapsacks.append(current_knapsack) - - return knapsacks - - -class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked, dict]): - """A simple task encoder for VLMs.""" - - def __init__( - self - ): - super().__init__() - - self.args = get_args() - - self.tokenizer = get_tokenizer() - with open(self.args.prompt_path, "r") as f: - self.manual_prompts = json.load(f) - self.dataloader_seq_length = self.args.dataloader_seq_length # Always return samples of this length. - self.packing_seq_length = self.args.packing_seq_length # Packing sequence length, if packing is enabled. - self.is_packing_enabled = self.args.packing_buffer_size is not None and self.args.packing_buffer_size > 0 - - if self.dataloader_seq_length and self.packing_seq_length: - assert self.dataloader_seq_length >= self.packing_seq_length, "dataloader sequence length must be greater than or equal to the packing sequence length" - - if self.is_packing_enabled: - assert self.packing_seq_length > 0, "packing sequence length must be set" - - self.num_image_embeddings_per_tile = get_num_image_embeddings( - self.args.img_h, - self.args.img_w, - self.args.patch_dim, - self.args.vision_model_type, - self.args.disable_vision_class_token, - 1, - self.args.pixel_shuffle, - self.args.use_tile_tags, - ) - - self.txt_to_token_dict = {} - - self.img_h, self.img_w = self.args.img_h, self.args.img_w - - # This map is used to reduce the number of tiles used per image if the number of tokens is - # larger than the decoder_seq_length. 
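The `greedy_knapsack` helper above drives sequence packing: samples are sorted by length and repeatedly packed into bins of `max_capacity`, each time taking the largest remaining sample that still fits (located with `bisect`). A self-contained sketch of the same algorithm, with a small usage example on made-up token lengths:

```python
import bisect
from typing import List

def search_for_fit(numbers: List[int], capacity: int) -> int:
    """Index of the largest number that still fits the remaining capacity, or -1 if none fits."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else index - 1

def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List[List]:
    """Pack samples greedily: always add the largest remaining sample that still fits."""
    assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length."
    if not item_sizes:
        return []
    # Sort lengths and samples together, smallest first.
    sizes, items = map(list, zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0])))
    if sizes[-1] > max_capacity:
        raise ValueError(f"a sample of length {sizes[-1]} exceeds max_capacity {max_capacity}")
    knapsacks = []
    while sizes:
        knapsack, remaining = [], max_capacity
        while (idx := search_for_fit(sizes, remaining)) != -1:
            remaining -= sizes.pop(idx)          # remove the chosen length...
            knapsack.append(items.pop(idx))      # ...and the corresponding sample.
        knapsacks.append(knapsack)
    return knapsacks

# Example: pack samples of these token lengths into bins of capacity 10.
print(greedy_knapsack([3, 7, 2, 5, 4], ["a", "b", "c", "d", "e"], 10))
# -> [['b', 'a'], ['d', 'e'], ['c']]
```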
- self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1} - - def _get_total_seq_length(self, input_ids, num_tiles): - """Calculate expected sequence length given text tokens length and number of tiles.""" - total_num_images = len(num_tiles) - total_num_tiles = sum(num_tiles) - total_len = len(input_ids) + total_num_tiles * self.num_image_embeddings_per_tile - total_num_images - return total_len - - def _truncate_for_packing(self, input_ids, target, num_tiles): - """Truncate tokens and labels if they exceed packing sequence length.""" - total_num_images = len(num_tiles) - total_num_tiles = sum(num_tiles) - total_img_embeddings_len = total_num_tiles * self.num_image_embeddings_per_tile - max_text_tokens = self.packing_seq_length - total_img_embeddings_len + total_num_images - - input_ids = input_ids[:max_text_tokens] - target = target[:max_text_tokens] - - # If truncate causes all labels to be ignored, then skip the sample - if (target == IGNORE_INDEX).all(): - raise ValueError(f"all targets will be ignored after truncation: {input_ids}") - - return input_ids, target - - @stateless(restore_seeds=True) - def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, SimilarityInterleavedSample]): - if isinstance(sample, OCRSample): - if "pdfa" in sample.__key__: - yield self.combined_ocr_encoder(sample, task_type='encode_pdf') - elif "multi" in sample.__key__: - yield self.combined_ocr_encoder(sample, task_type='_encode_ocr') - else: - yield self.combined_ocr_encoder(sample, task_type='encode_ocr_ref') - elif isinstance(sample, CaptioningSample): - yield self.encode_captioning(sample) - elif isinstance(sample, VQASample): - is_llava_training = sample.__subflavors__["is_llava_training"] if "is_llava_training" in sample.__subflavors__ else False - - if "llava" in sample.__key__ or is_llava_training: - yield self.encode_llava_pretrain(sample) - else: - yield self.encode_any_single_turn_vqa(sample) - elif isinstance(sample, SimilarityInterleavedSample): - yield self.encode_llava_sft(sample) - elif isinstance(sample, MultiChoiceVQASample): - yield self.encode_any_single_turn_vqa(sample) - else: - raise NotImplementedError("Sample format not supported", sample) - - def encode_captioning(self, sample: CaptioningSample): - """Encode CaptioningSample.""" - augment = sample.__subflavors__.get("augmentation") - - imgs = get_visual_transform( - sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, - self.args.vision_model_type, - ) - num_tiles = [len(imgs)] - - prompt_list = self.manual_prompts["CaptioningPretraining"]["raw"] - - prompt_idx = np.random.randint(len(prompt_list)) - cur_prompt = prompt_list[prompt_idx] - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n" - - caption = sample.caption.strip() - - split_by_line_flag = sample.__subflavors__.get("SplitByLine") - if split_by_line_flag: - caption_list = caption.split('\n') - caption = np.random.choice(caption_list) - - conv = [ - # Note: no system message. 
- {"role": "user", "content": cur_prompt}, - {"role": "assistant", "content": caption}, - ] - - input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) - - if self.is_packing_enabled: - input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) - - return ImageTaskSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavor__=None, - __subflavors__=sample.__subflavors__, - imgs=imgs, - num_tiles=num_tiles, - tokens=torch.tensor(input_ids), - labels=torch.tensor(target), - total_len=self._get_total_seq_length(input_ids, num_tiles), - ) - - def encode_llava_pretrain(self, sample: VQASample): - """Encode pretrain sample in LLAVA style.""" - augment = sample.__subflavors__.get("augmentation", False) - - imgs = get_visual_transform( - sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, - self.args.vision_model_type, - ) - num_tiles = [len(imgs)] - - # LLAVA training: override text-prompt with just the image. - conv = [ - # Note: no system message. - {"role": "user", "content": IMAGE_TOKEN + "\n"}, - {"role": "assistant", "content": sample.answers}, - ] - - input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) - - if self.is_packing_enabled: - input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) - - return ImageTaskSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavor__=None, - __subflavors__=sample.__subflavors__, - imgs=imgs, - num_tiles=num_tiles, - tokens=torch.tensor(input_ids), - labels=torch.tensor(target), - total_len=self._get_total_seq_length(input_ids, num_tiles), - ) - - def encode_llava_sft(self, sample: SimilarityInterleavedSample): - """Encode SFT sample.""" - augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False - has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False - - has_image = False - if hasattr(sample, "images"): - # If this is a text-only sample and we are freezing the LM, - # then use a dummy input image. - if len(sample.images) == 0 and self.args.freeze_LM: - empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255)) - sample.images.append(empty_img) - if len(sample.images) > 0 and not has_video: - has_image = True - - # Note: Some tokenizers may ignore the system prompt. - conversation = [{"role": "system", "content": "Answer the questions."}] - # Format the conversation as a list of "user" / "assistant" turns. - for text in sample.texts: - error_msg = f"unexpected role {text['from']} in {sample.texts}" - assert text["from"] in ["human", "gpt"], error_msg - conversation.append({ - "role": "user" if text["from"] == "human" else "assistant", - "content": text["value"]}) - - # Replace the image tags with IMAGE_TOKEN and count the number of image tags - number_image_tags = 0 - image_tag_ids_list = [] - for turn in conversation: - if turn["role"] == "user": - image_tag_ids = [int(x) - 1 for x in re.findall(r"", turn["content"])] - image_tag_ids_list.extend(image_tag_ids) - turn["content"] = re.sub(r"", IMAGE_TOKEN, turn["content"]) - number_image_tags += turn["content"].count(IMAGE_TOKEN) - # For videos, we replace the image tag with the video tag - if has_video: - turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN) - - # We re-order the images in sample.images according to how they appear in the conversation. 
- if len(image_tag_ids_list) > 0: - sample.images = [sample.images[idx] for idx in image_tag_ids_list] - - # If there is only one image, but several image tags, we assume all the tags refer to the - # same image and duplicate the image: - if len(sample.images) == 1 and number_image_tags > 1: - sample.images = sample.images * number_image_tags - - number_of_images = len(sample.images) - # Fail if there are more image or video tags than image or videos: - error_msg = ( - f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}") - assert number_image_tags <= number_of_images, error_msg - - # If there are less image of video tags than image or videos, prepend the tags to the first - # user message: - if number_image_tags < number_of_images: - for turn in conversation: - if turn["role"] == "user": - tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN - turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"] - break - - input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) - - if has_image: - imgs = [] - num_tiles = [] - max_num_tiles = self.args.max_num_tiles - # We keep a buffer of 4 tokens for the question, - # the rest can be used for image tokens. - max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4 - # We start by extracting as many tiles per image as possible, and decrease the max - # number of tiles if there are too many image tokens. - while True: - imgs = [] - num_tiles = [] - for img in sample.images: - img_tiles = get_visual_transform( - img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles, - self.args.use_thumbnail, augment, self.args.vision_model_type) - imgs += img_tiles - num_tiles += [len(img_tiles)] - if max_num_tiles == 1: - break - if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed: - if max_num_tiles in self.num_tiles_degradation_map: - max_num_tiles = self.num_tiles_degradation_map[max_num_tiles] - else: - raise RuntimeError(( - f"Tried to decrease the number of tiles {max_num_tiles} but it's not ", - f"defined in the degradation map {self.num_tiles_degradation_map}")) - else: - break - elif has_video: - # We don't use tiling for videos to limit the number of tokens. - use_tiling=False - # Grab the selected frames of the video as a tensor with shape - # fhwc: (num_frames, num_channels, height, width). - video_fchw = sample.images[0].permute(0, 1, 2, 3) - selected_frames = torch.linspace( - 0, video_fchw.shape[0] - 1, self.args.num_frames).long() - video_fchw = video_fchw[selected_frames] - imgs = [] - for video_chw in video_fchw: - to_pil = ToPILImage() - video_chw = to_pil(video_chw) - imgs += get_visual_transform( - video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles, - self.args.use_thumbnail, augment, self.args.vision_model_type) - num_tiles = [len(imgs)] - else: - imgs = num_tiles = [] - - if self.is_packing_enabled: - input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) - - # Some final checks with respect to the number of image tokens and images on the tokenized - # conversation. There can still be errors, for instance if a non-video sample happens to - # have our pre-defined video token, or if the packing truncation removed a necessary image - # tag. 
- number_image_token = np.sum(input_ids == self.img_token_id) - error_msg = ( - f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.") - assert number_image_token == len(num_tiles), error_msg - error_msg = ( - f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.") - assert np.sum(num_tiles) == len(imgs), error_msg - - return ImageTaskSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavor__=None, - __subflavors__=sample.__subflavors__, - imgs=imgs, - num_tiles=num_tiles, - tokens=torch.tensor(input_ids), - labels=torch.tensor(target), - total_len=self._get_total_seq_length(input_ids, num_tiles), - ) - - def encode_any_single_turn_vqa(self, sample): - """Encode MultiChoiceVQA or VQA sample.""" - augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False - has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False - - if has_video: - # Grab the selected frames of the video as a tensor with shape - # fhwc: (num_frames, height, width, num_channels). - video_fhwc = sample.image.permute(0, 2, 3, 1) - selected_frames = torch.linspace( - 0, video_fhwc.shape[0] - 1, self.args.num_frames).long() - video_frame_fhwc = video_fhwc[selected_frames] - imgs = [] - for video_frame_hwc in video_frame_fhwc: - imgs += get_visual_transform( - video_frame_hwc, self.img_h, self.img_w, - self.args.use_tiling, self.args.max_num_tiles, - self.args.use_thumbnail, augment, self.args.vision_model_type) - else: - imgs = get_visual_transform( - sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, - self.args.use_thumbnail, augment, self.args.vision_model_type, - ) - - num_tiles = [len(imgs)] - - if isinstance(sample, MultiChoiceVQASample): - cur_prompt = format_multichoice_question(sample.context, sample.choices) - if IMAGE_TOKEN not in cur_prompt: - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt - cur_answer = format_multichoice_answer(sample.correct_choice_idx) - elif isinstance(sample, VQASample): - if 'docvqa' in sample.__key__: - prompt_list = self.manual_prompts["VQASFT"]["docvqa"] - elif sample.__subflavors__.get("VQASFT"): - prompt_list = self.manual_prompts["VQASFT"]["raw"] - else: - prompt_list = ["{}"] - - prompt_idx = np.random.randint(len(prompt_list)) - cur_prompt = prompt_list[prompt_idx] - - cur_prompt = cur_prompt.format(sample.context) - - if IMAGE_TOKEN not in cur_prompt: - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt - - if isinstance(sample.answers, list): - answer_list = sample.answers - weight_list = np.array(sample.answer_weights).astype(np.float32) - weight_list = weight_list / np.sum(weight_list) - answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0] - cur_answer = answer_list[answer_idx] - else: - cur_answer = sample.answers - else: - raise NotImplementedError("Unsupported data type provided", sample) - - conversation = [ - {"role": "system", "content": "Answer the questions."}, - {"role": "user", "content": cur_prompt}, - {"role": "assistant", "content": str(cur_answer)}, - ] - - input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) - - if self.is_packing_enabled: - input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) - - return ImageTaskSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavor__=None, - __subflavors__=sample.__subflavors__, - 
imgs=imgs, - num_tiles=num_tiles, - tokens=torch.tensor(input_ids), - labels=torch.tensor(target), - total_len=self._get_total_seq_length(input_ids, num_tiles), - ) - - def combined_ocr_encoder(self, sample, task_type): - """Encode OCR samples.""" - augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False - - if task_type == "encode_pdf": - sample, cur_prompt, cur_answer = self.encode_pdf_prompt(sample) - elif task_type == "encode_ocr_ref": - sample, cur_prompt, cur_answer = self.encode_ocr_ref_prompt(sample) - elif task_type == "_encode_ocr": - sample, cur_prompt, cur_answer = self.encode_ocr_prompt(sample) - - imgs = get_visual_transform( - sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, - self.args.use_thumbnail, augment, self.args.vision_model_type, - ) - num_tiles = [len(imgs)] - - conversation = [ - {"role": "system", "content": "Answer the questions."}, - {"role": "user", "content": cur_prompt}, - {"role": "assistant", "content": str(cur_answer)}, - ] - - input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) - - if self.is_packing_enabled: - input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) - - return ImageTaskSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavor__=None, - __subflavors__=sample.__subflavors__, - imgs=imgs, - num_tiles=num_tiles, - tokens=torch.tensor(input_ids), - labels=torch.tensor(target), - total_len=self._get_total_seq_length(input_ids, num_tiles), - ) - - def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample: - """Encode OCR sample.""" - prompt_list = self.manual_prompts["DocPretraining"]["raw"] - prompt_idx = np.random.randint(len(prompt_list)) - cur_prompt = prompt_list[prompt_idx] - if IMAGE_TOKEN not in cur_prompt: - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt - - # Make sure there is no extra IMAGE_TOKEN tag. 
- sample.text = sample.text.replace(IMAGE_TOKEN, "") - - caption = sample.text.strip() - - split_by_line_flag = sample.__subflavors__.get("SplitByLine") - if split_by_line_flag: - caption_list = caption.split('\n') - caption = np.random.choice(caption_list) - cur_answer = caption - - return sample, cur_prompt, cur_answer - - def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample: - """Encode OCR sample.""" - ref = sample.text - region = sample.words_boxes - - # Make sure there is no extra IMAGE_TOKEN tag - ref = ref.replace(IMAGE_TOKEN, "") - - if len(region) == 4: - region = f"({region[0]},{region[1]}),({region[2]},{region[3]})" - else: - region = f"({region[0]},{region[1]}),({region[2]},{region[3]}),({region[4]},{region[5]}),({region[6]},{region[7]})" - - # Randomly choose between two tasks - task_idx = np.random.randint(2) - if task_idx == 0: - # Referring Grounding - prompt_list = self.manual_prompts["DocPretraining"]["referring_grounding"] - prompt_content = ref - answer = region - else: - # Grounded OCR - prompt_list = self.manual_prompts["DocPretraining"]["grounded_ocr"] - prompt_content = region - answer = ref - - prompt_idx = np.random.randint(len(prompt_list)) - cur_prompt = prompt_list[prompt_idx] - cur_prompt = cur_prompt.format(prompt_content) - if IMAGE_TOKEN not in cur_prompt: - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt - - return sample, cur_prompt, answer - - def bbox_coord_to_label(self, text, bbox): - """Format bbox coordinates as text.""" - assert len(bbox) == 4 or len(bbox) == 8 - - # Make sure there is no extra IMAGE_TOKEN tag - text = text.replace(IMAGE_TOKEN, "") - - if len(bbox) == 4: - label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})" - else: - label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]}),({bbox[4]},{bbox[5]}),({bbox[6]},{bbox[7]})" - - return label_str - - def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample: - """Encode OCR sample.""" - if isinstance(sample.words_boxes[0], int): - answer = self.bbox_coord_to_label(sample.text, sample.words_boxes) - elif isinstance(sample.words_boxes[0], list): - answer = "" - for i, bbox in enumerate(sample.words_boxes): - answer += self.bbox_coord_to_label(sample.words_text[i], bbox) - - prompt_list = self.manual_prompts["DocPretraining"]["ocr_multi"] - prompt_idx = np.random.randint(len(prompt_list)) - cur_prompt = prompt_list[prompt_idx] - - if IMAGE_TOKEN not in cur_prompt: - cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt - cur_answer = answer - - return sample, cur_prompt, cur_answer - - def batch(self, samples: List[Union[ImageTaskSample, ImageTaskSamplePacked]]) -> ImageTaskBatchPacked: - # Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image. - imgs = [img for s in samples for img in s.imgs] - if len(imgs) > 0: - imgs = torch.stack(imgs) - else: - imgs = torch.tensor([[0]], dtype=torch.float32) - - # If the user hasn't defined a target dataloader sequence length, then use the max along the sample lengths. - max_seq_len = self.dataloader_seq_length - if not max_seq_len: - max_seq_len = max(len(s.tokens) for s in samples) - - tokens = np.full((len(samples), max_seq_len), self.tokenizer.pad, dtype=np.int64) - # +1 to accommodate shift to left by one later. - labels = np.full((len(samples), max_seq_len + 1), self.tokenizer.pad, dtype=np.int64) - - for i, s in enumerate(samples): - # If the sample/target length exceeds the target sequence length, then truncate. 
- text_len = min(max_seq_len, len(s.tokens)) - target_len = min(max_seq_len+1, len(s.labels)) - - tokens[i, :text_len] = s.tokens[:text_len] - labels[i, :target_len] = s.labels[:target_len] - - num_tiles = torch.tensor([n for s in samples for n in s.num_tiles], dtype=torch.int32) - if len(num_tiles) == 0: - num_tiles = torch.tensor([[0]], dtype=torch.int32) - - # Cumulative sample lengths are needed for packing, otherwise use dummy values. - cu_lengths = torch.tensor([[0]], dtype=torch.int32) - max_lengths = torch.tensor([[0]], dtype=torch.int32) - - if self.is_packing_enabled: - cu_lengths = torch.stack([s.cu_lengths for s in samples]) - max_lengths = torch.tensor([s.max_length for s in samples], dtype=torch.int32) - - return ImageTaskBatchPacked( - __key__=[s.__key__ for s in samples], - __restore_key__=[s.__restore_key__ for s in samples], - __subflavor__=None, - __subflavors__=samples[0].__subflavors__, - tokens=tokens, - labels=labels, - imgs=imgs, - num_tiles=num_tiles, - cu_lengths=cu_lengths, - max_lengths=max_lengths, - ) - - def encode_batch(self, batch: ImageTaskBatchPacked) -> dict: - raw = dataclasses.asdict(batch) - del raw["__subflavors__"] - return raw - - def select_samples_to_pack(self, samples: List[ImageTaskSample]) -> List[List[ImageTaskSample]]: - """Selects which samples will be packed together. - - NOTE: Energon dataloader calls this method internally if packing is used. - Please see https://nvidia.github.io/Megatron-Energon/packing.html - """ - lengths = [sample.total_len for sample in samples] - - packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length) - - return packed_samples - - @stateless - def pack_selected_samples(self, samples: List[ImageTaskSample]) -> List[ImageTaskSamplePacked]: - """ - Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked. - - NOTE: Energon dataloader calls this method internally if packing is used. - Please see https://nvidia.github.io/Megatron-Energon/packing.html - - Args: - samples: List of ImageTaskSample instances to pack into one sample. - - Returns: - ImageTaskSamplePacked instance. - """ - packing_seq_len = self.packing_seq_length - - packed_tokens = [] - packed_labels = [] - packed_imgs = [] - - current_length = 0 - max_length = 0 - cu_lengths = [0] - - # Process each sample and build lists that we will concatenate to create the packed sample. - for _, sample in enumerate(samples): - sample_len = sample.total_len - - if sample_len > max_length: - max_length = sample_len - - # If adding this sample exceeds the max length, stop. - # This should not happen. The select_samples_to_pack method should have already ensured that the samples fit. - if current_length + sample_len > packing_seq_len: - raise ValueError(f"Packed sample exceeds the maximum sequence length of {packing_seq_len}: {samples}") - - # Add the sample's tokens and labels - packed_tokens.append(sample.tokens) - packed_labels.append(sample.labels) - - # Add the images - packed_imgs += sample.imgs - - current_length += sample_len - cu_lengths.append(current_length) - - # Concatenate packed tokens and labels. 
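When packing is enabled, each packed sample records `cu_lengths`, the cumulative total lengths (text plus image tokens) of its sub-samples, so the packed sequence can later be split back apart. A small illustrative calculation with made-up lengths:

```python
import torch

# Hypothetical total lengths (text + image tokens) of three sub-samples packed together.
sample_lengths = [5, 3, 4]

cu_lengths = [0]
for length in sample_lengths:
    cu_lengths.append(cu_lengths[-1] + length)

cu_lengths = torch.tensor(cu_lengths, dtype=torch.int32)
print(cu_lengths)                 # tensor([ 0,  5,  8, 12], dtype=torch.int32)
max_length = max(sample_lengths)  # 5, stored as `max_length` on the packed sample
```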
- packed_tokens = torch.cat(packed_tokens, dim=0) - packed_labels = torch.cat(packed_labels, dim=0) - - return ImageTaskSamplePacked( - __key__=",".join([s.__key__ for s in samples]), - __restore_key__=(), # Will be set by energon based on `samples` - __subflavor__=None, - __subflavors__=samples[0].__subflavors__, - tokens=packed_tokens, - labels=packed_labels, - imgs=packed_imgs, - cu_lengths=torch.tensor(cu_lengths, dtype=torch.int32), - max_length=max_length, - num_tiles=[n for s in samples for n in s.num_tiles], - ) - - -def print_error_handler(exc: Exception, key: Optional[str]): - print( - f"The following exception occurred in the dataloader for sample {key} and is skipped", - file=sys.stderr, - ) - traceback.print_exc() - - -def format_multichoice_question(question, multichoice_options): - """Format multi-choice question.""" - options_text = ["{}. {}\n".format(chr(ord('A') + i), option) for i, option in - zip(range(len(multichoice_options)), multichoice_options)] - options_text = "".join(options_text) - - options_text = f"{options_text}Answer with the option's letter from the given choices directly." - - return "{}\n{}".format(question, options_text) - - -def format_multichoice_answer(idx): - """Format multi-choice answer.""" - return chr(ord('A') + idx) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import bisect +import dataclasses +import json +import re +import sys +import traceback +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from image_processing import find_closest_aspect_ratio, find_closest_area_weighted_aspect_ratio, get_visual_transform +from PIL import Image +from torchvision.transforms import ToPILImage +import numpy as np +import torch + +from energon_util import OfflineTargetAspectRatioSample, SampleListSample +from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.energon import ( + Batch, + CaptioningSample, + DefaultTaskEncoder, + OCRSample, + Sample, + SimilarityInterleavedSample, + VQASample, + MultiChoiceVQASample +) +from megatron.energon.task_encoder.base import stateless +from megatron.training import get_args, get_tokenizer + + +@dataclass +class ImageTaskSample(Sample): + __key__: str + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict + __subflavors__: Dict + # (c, h, w) + imgs: List[torch.Tensor] + num_tiles: List[int] + tokens: torch.Tensor + total_len: int # Total token count in the sample, including text and image tokens + labels: torch.Tensor = None + + +@dataclass +class ImageTaskSamplePacked(Sample): + """Dataclass to store a single packed sample (not a batch). + + P = Number of sub-samples in the packed sample + seq_len = Total sequence length + num_imgs = Number of images across all samples in the packed sample + """ + + __key__: str # Sample name + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict # Sample metadata. Deprecated. + __subflavors__: Dict # Sample metadata. + tokens: torch.Tensor # Input tokens packed into a single tensor (seq_len,) + labels: torch.Tensor # Target tokens packed into a single tensor (seq_len,) + imgs: List[torch.Tensor] # Input images + num_tiles: List[int] # Number of tiles for each image of each sample (num_imgs) + max_length: int # Maximum length across sub-samples. + cu_lengths: List[int] # Cumulative length of each sub-sample in this packed sample incl. 
text and image tokens (P,) + + +# Typing for the resulting batch data after encode_batch() +@dataclass +class ImageTaskBatchPacked(Batch): + """Dataclass to store a batch of packed samples. + + N = Batch size + P = Number of samples in the packed sample + seq_len = Maximum sequence length + num_imgs = Number of images across all samples in the packed sample + """ + + __key__: List[str] # Sample names + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict # Sample metadata. Deprecated. + __subflavors__: List[Dict] # Sample metadatas. + tokens: torch.Tensor # Input tokens packed and padded (N, seq_len) + labels: torch.Tensor # Target tokens packed and padded (N, seq_len) + imgs: torch.Tensor # All image tiles stacked into a single tensor (num_tiles, C, H, W) + num_tiles: List[List[int]] # Number of tiles per image (N, num_imgs) + max_lengths: List[int] # Maximum length across sub-samples (N,) + cu_lengths: List[List[int]] # Cumulative length of each sub-sample in each packed sample of the batch (N, P) + + +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L19 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. +def search_for_fit(numbers: List[int], capacity: int) -> int: + """Finds the index of largest number that fits into the knapsack with the given capacity.""" + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) + + +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L27 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. +def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List: + """Greedy algorithm with binary search for the knapsack problem. + + Pack as many samples as possible given a maximum capacity and capacities of individual samples. + Used if sequence packing is enabled. + """ + assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length." + + knapsacks = [] + + if len(item_sizes) == 0: + return knapsacks + + # Sort sample lengths and samples together. + sorted_item_sizes, sorted_samples = zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0])) + sorted_item_sizes = list(sorted_item_sizes) + sorted_samples = list(sorted_samples) + + # Check if all samples fit in the knapsack capacity. + if sorted_item_sizes[-1] > max_capacity: + raise ValueError(f"knapsack: A sample is larger {sorted_item_sizes[-1]} than the max_sequence_length {max_capacity}.") + + while sorted_item_sizes: + current_knapsack = [] + remaining_capacity = max_capacity + + while True: + idx = search_for_fit(sorted_item_sizes, remaining_capacity) + if idx == -1: + break # Can't fit more samples. + + remaining_capacity -= sorted_item_sizes[idx] + + sorted_item_sizes.pop(idx) + sample = sorted_samples.pop(idx) + current_knapsack.append(sample) + + knapsacks.append(current_knapsack) + + return knapsacks + + +class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked, dict]): + """A simple task encoder for VLMs.""" + + def __init__( + self + ): + super().__init__() + + self.args = get_args() + + self.tokenizer = get_tokenizer() + with open(self.args.prompt_path, "r") as f: + self.manual_prompts = json.load(f) + self.dataloader_seq_length = self.args.dataloader_seq_length # Always return samples of this length. 
+ self.packing_seq_length = self.args.packing_seq_length # Packing sequence length, if packing is enabled. + self.is_packing_enabled = self.args.packing_buffer_size is not None and self.args.packing_buffer_size > 0 + + if self.dataloader_seq_length and self.packing_seq_length: + assert self.dataloader_seq_length >= self.packing_seq_length, "dataloader sequence length must be greater than or equal to the packing sequence length" + + if self.is_packing_enabled: + assert self.packing_seq_length > 0, "packing sequence length must be set" + + self.num_image_embeddings_per_tile = get_num_image_embeddings( + self.args.img_h, + self.args.img_w, + self.args.patch_dim, + self.args.vision_model_type, + self.args.disable_vision_class_token, + 1, + self.args.pixel_shuffle, + self.args.use_tile_tags, + ) + + self.txt_to_token_dict = {} + + self.img_h, self.img_w = self.args.img_h, self.args.img_w + self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + # This map is used to reduce the number of tiles used per image if the number of tokens is + # larger than the decoder_seq_length. + self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1} + + self.find_closest_aspect_ratio_fn = ( + find_closest_area_weighted_aspect_ratio if self.args.use_area_weighted_aspect_ratio + else find_closest_aspect_ratio) + + def _get_total_seq_length(self, input_ids, num_tiles): + """Calculate expected sequence length given text tokens length and number of tiles.""" + total_num_images = len(num_tiles) + total_num_tiles = sum(num_tiles) + total_len = len(input_ids) + total_num_tiles * self.num_image_embeddings_per_tile - total_num_images + return total_len + + def _truncate_for_packing(self, input_ids, target, num_tiles): + """Truncate tokens and labels if they exceed packing sequence length.""" + total_num_images = len(num_tiles) + total_num_tiles = sum(num_tiles) + total_img_embeddings_len = total_num_tiles * self.num_image_embeddings_per_tile + max_text_tokens = self.packing_seq_length - total_img_embeddings_len + total_num_images + + input_ids = input_ids[:max_text_tokens] + target = target[:max_text_tokens] + + # If truncate causes all labels to be ignored, then skip the sample + if (target == IGNORE_INDEX).all(): + raise ValueError(f"all targets will be ignored after truncation: {input_ids}") + + return input_ids, target + + @stateless(restore_seeds=True) + def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, SimilarityInterleavedSample]): + if isinstance(sample, OCRSample): + if "pdfa" in sample.__key__: + yield self.combined_ocr_encoder(sample, task_type='encode_pdf') + elif "multi" in sample.__key__: + yield self.combined_ocr_encoder(sample, task_type='_encode_ocr') + else: + yield self.combined_ocr_encoder(sample, task_type='encode_ocr_ref') + elif isinstance(sample, CaptioningSample): + yield self.encode_captioning(sample) + elif isinstance(sample, VQASample): + is_llava_training = sample.__subflavors__["is_llava_training"] if "is_llava_training" in sample.__subflavors__ else False + + if "llava" in sample.__key__ or is_llava_training: + yield self.encode_llava_pretrain(sample) + else: + yield self.encode_any_single_turn_vqa(sample) + elif isinstance(sample, SimilarityInterleavedSample): + yield self.encode_llava_sft(sample) + elif isinstance(sample, MultiChoiceVQASample): + yield self.encode_any_single_turn_vqa(sample) + # Because the SampleListSample is defined in the Megatron module but loaded by the Energon + # library, we need to resort to the more brittle check: 
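`_get_total_seq_length` and `_truncate_for_packing` above both account for the fact that each image placeholder token in `input_ids` is later expanded into `num_image_embeddings_per_tile` embeddings per tile, so one already-counted placeholder per image is subtracted back out. A worked sketch of that accounting with made-up numbers (576 embeddings per tile and the packing length are assumptions for illustration):

```python
# Hypothetical numbers: 576 embeddings per tile, two images with 4 and 1 tiles,
# and 120 text tokens (each image contributes one placeholder token to the text).
num_image_embeddings_per_tile = 576
num_tiles = [4, 1]
num_text_tokens = 120

total_num_images = len(num_tiles)   # 2
total_num_tiles = sum(num_tiles)    # 5

# Expanded sequence length: text tokens + all tile embeddings - the placeholder tokens
# that those embeddings replace.
total_len = num_text_tokens + total_num_tiles * num_image_embeddings_per_tile - total_num_images
print(total_len)  # 120 + 5 * 576 - 2 = 2998

# Truncation budget for packing: how many text tokens still fit next to the image embeddings.
packing_seq_length = 4096
max_text_tokens = packing_seq_length - total_num_tiles * num_image_embeddings_per_tile + total_num_images
print(max_text_tokens)  # 4096 - 2880 + 2 = 1218
```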
+ elif type(sample).__name__ == "SampleListSample": + yield self.encode_sample_list(sample) + else: + raise NotImplementedError("Sample format not supported", sample) + + def encode_captioning(self, sample: CaptioningSample): + """Encode CaptioningSample.""" + augment = sample.__subflavors__.get("augmentation") + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, + self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + prompt_list = self.manual_prompts["CaptioningPretraining"]["raw"] + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n" + + caption = sample.caption.strip() + + split_by_line_flag = sample.__subflavors__.get("SplitByLine") + if split_by_line_flag: + caption_list = caption.split('\n') + caption = np.random.choice(caption_list) + + conv = [ + # Note: no system message. + {"role": "user", "content": cur_prompt}, + {"role": "assistant", "content": caption}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_llava_pretrain(self, sample: VQASample): + """Encode pretrain sample in LLAVA style.""" + augment = sample.__subflavors__.get("augmentation", False) + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, + self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + # LLAVA training: override text-prompt with just the image. + conv = [ + # Note: no system message. 
+ {"role": "user", "content": IMAGE_TOKEN + "\n"}, + {"role": "assistant", "content": sample.answers}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_sample_list(self, samples: SampleListSample): + """We encode the list of samples using encode_llava_sft on each sample.""" + error_msg = ("You probably don't want to use online packing since SampleListSample is " + "usually used along offline packing.") + assert not self.is_packing_enabled, error_msg + encoded_samples = [] + current_length = 0 + for sample in samples.samples: + encoded_sample = self.encode_llava_sft(sample, truncate_for_sample_list_packing=True) + if current_length + encoded_sample.total_len > self.packing_seq_length: + break + else: + encoded_samples.append(encoded_sample) + current_length += encoded_sample.total_len + return self.pack_selected_samples(encoded_samples) + + def encode_llava_sft(self, sample: Union[SimilarityInterleavedSample, OfflineTargetAspectRatioSample], truncate_for_sample_list_packing=False): + """Encode SFT sample.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False + + # If the target aspect ratio are provided by the dataset, we use them instead of computing + # them with the self.find_closest_aspect_ratio_fn function. + local_find_closest_aspect_ratio_fn = self.find_closest_aspect_ratio_fn + if type(sample).__name__ == "OfflineTargetAspectRatioSample": + target_aspect_ratio = tuple(sample.target_aspect_ratio[0]) + assert target_aspect_ratio is not None, "Sample of type OfflineTargetAspectRatioSample needs to define the target aspect ratio." + local_find_closest_aspect_ratio_fn = lambda *args, **kwargs: target_aspect_ratio + + has_image = False + # We infer whether the sample has image or not. + if hasattr(sample, "images") and not has_video: + # If this is a text-only sample and we are freezing the LM, + # then use a dummy input image. + if len(sample.images) == 0 and self.args.freeze_LM: + empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255)) + sample.images.append(empty_img) + if len(sample.images) > 0: + has_image = True + + # Note: Some tokenizers may ignore the system prompt. + conversation = [{"role": "system", "content": "Answer the questions."}] + # Format the conversation as a list of "user" / "assistant" turns. 
+ for text in sample.texts: + error_msg = f"unexpected role {text['from']} in {sample.texts}" + assert text["from"] in ["human", "gpt"], error_msg + conversation.append({ + "role": "user" if text["from"] == "human" else "assistant", + "content": text["value"]}) + + # Replace the image tags with IMAGE_TOKEN and count the number of image tags + number_image_tags = 0 + image_tag_ids_list = [] + for turn in conversation: + if turn["role"] == "user": + image_tag_ids = [int(x) - 1 for x in re.findall(r"", turn["content"])] + image_tag_ids_list.extend(image_tag_ids) + turn["content"] = re.sub(r"", IMAGE_TOKEN, turn["content"]) + # For videos, we use the image token to locate where to put the frames. + if has_video: + turn["content"] = turn["content"].replace(VIDEO_TOKEN, IMAGE_TOKEN) + number_image_tags += turn["content"].count(IMAGE_TOKEN) + + # We re-order the images in sample.images according to how they appear in the conversation. + if len(image_tag_ids_list) > 0: + sample.images = [sample.images[idx] for idx in image_tag_ids_list] + + # If there is only one image, but several image tags, we assume all the tags refer to the + # same image and duplicate the image: + if not has_video and len(sample.images) == 1 and number_image_tags > 1: + sample.images = sample.images * number_image_tags + + # We currently only support one video per sample. + number_of_images = 1 if has_video else len(sample.images) + # Fail if there are more image or video tags than image or videos: + error_msg = ( + f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}") + assert number_image_tags <= number_of_images, error_msg + + # If there are less image of video tags than image or videos, prepend the tags to the first + # user message: + if number_image_tags < number_of_images: + for turn in conversation: + if turn["role"] == "user": + turn["content"] = IMAGE_TOKEN*(number_of_images-number_image_tags) + "\n" + turn["content"] + break + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if has_image: + imgs = [] + num_tiles = [] + max_num_tiles = self.args.max_num_tiles + # We keep a buffer of 4 tokens for the question, + # the rest can be used for image tokens. + max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4 + # We start by extracting as many tiles per image as possible, and decrease the max + # number of tiles if there are too many image tokens. + while True: + imgs = [] + num_tiles = [] + for img in sample.images: + img_tiles = get_visual_transform( + img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn) + imgs += img_tiles + num_tiles += [len(img_tiles)] + if max_num_tiles == 1: + break + if sum(num_tiles) * self.num_image_embeddings_per_tile > max_image_token_allowed: + if max_num_tiles in self.num_tiles_degradation_map: + max_num_tiles = self.num_tiles_degradation_map[max_num_tiles] + else: + raise RuntimeError(( + f"Tried to decrease the number of tiles {max_num_tiles} but it's not ", + f"defined in the degradation map {self.num_tiles_degradation_map}")) + else: + break + elif has_video: + # We don't use tiling for videos to limit the number of tokens. + use_tiling=False + # Grab the selected frames of the video as a tensor with shape + # fhwc: (num_frames, num_channels, height, width). 
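# --- Illustrative sketch (not part of the patch) -----------------------------
# The frame subsampling below uses torch.linspace to pick `num_frames` indices
# spread evenly over the available frames. With assumed numbers (48 frames in
# the video, 8 frames requested), the selected indices would be:
import torch

total_frames, requested_frames = 48, 8          # assumed values for illustration
frame_indices = torch.linspace(0, total_frames - 1, requested_frames).long()
# frame_indices -> tensor([ 0,  6, 13, 20, 26, 33, 40, 47]); the first and last
# frames are always included.
# ------------------------------------------------------------------------------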
+ video_fchw = sample.images.frames + if video_fchw.shape[0] == 0: + raise ValueError(f"Video {sample.__key__} {sample.__restore_key__} {sample.texts} has no frames.") + selected_frames = torch.linspace( + 0, video_fchw.shape[0] - 1, self.args.num_frames).long() + video_fchw = video_fchw[selected_frames] + imgs = [] + for video_chw in video_fchw: + to_pil = ToPILImage() + video_chw = to_pil(video_chw) + imgs += get_visual_transform( + video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn) + num_tiles = [len(imgs)] + else: + imgs = num_tiles = [] + + if self.is_packing_enabled or truncate_for_sample_list_packing: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + # Some final checks with respect to the number of image tokens and images on the tokenized + # conversation. There can still be errors, for instance if a non-video sample happens to + # have our pre-defined video token, or if the packing truncation removed a necessary image + # tag. + number_image_token = np.sum(input_ids == self.img_token_id) + error_msg = ( + f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.") + assert number_image_token == len(num_tiles), error_msg + error_msg = ( + f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.") + assert np.sum(num_tiles) == len(imgs), error_msg + + # We need to ensure that there are at least some trainable tokens in the sample. + assert self.target_has_trainable_tokens(input_ids, num_tiles, target), "Sample has no trainable tokens." + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def target_has_trainable_tokens(self, input_ids, num_tiles, target): + # Compute the loss mask based on extending the image tags with the proper + # number of image tokens, extracting the first self.args.decoder_seq_length tokens, and + # ensuring that some of these tokens have a loss mask > 0. + # Note that this is a bit hacky because we reproduce here parts of the logics which are in + # the model itself. Ideally, the data sampler would return the already processed inputs + # and targets to avoid this duplication. + expanded_target = target.copy() + expanded_target[input_ids==self.img_token_id] = self.img_token_id + expanded_target = self.replace_value_with_repetition( + expanded_target, self.img_token_id, + self.num_image_embeddings_per_tile * np.array(num_tiles), IGNORE_INDEX) + loss_mask = torch.ones(torch.tensor(expanded_target).size(), dtype=torch.float) + loss_mask[expanded_target == self.tokenizer.pad] = 0.0 # mask paddings + loss_mask[expanded_target == IGNORE_INDEX] = 0.0 # mask prompts + loss_mask = torch.cat((loss_mask[1:], torch.zeros((1,)))) + loss_mask = loss_mask[:self.args.decoder_seq_length] + return torch.sum(loss_mask) > 0 + + def replace_value_with_repetition(self, arr, token_to_replace, num_repetition, new_token): + """ + Replace every occurrence of value V in the input array with R repetitions of W. 
+ + Args: + arr (Array): Input array to be modified + token_to_replace: token to be replaced + new_token: new token + num_repetition (Array): number of repetition of new token. + + Returns: + Array: New array with token_to_replace replaced by num_repetition repetitions of + new_token + """ + error_msg = "The number of image tokens must match the length of the tile tensor." + assert np.sum(arr==token_to_replace) == len(num_repetition), error_msg + result = [] + idx = 0 + for item in arr: + if item == token_to_replace: + # If the current item matches token_to_replace, add R copies of W + result.extend([new_token] * num_repetition[idx]) + idx += 1 + else: + # Otherwise, keep the original item + result.append(item) + + return np.array(result) + + def encode_any_single_turn_vqa(self, sample): + """Encode MultiChoiceVQA or VQA sample.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False + + if has_video: + # Grab the selected frames of the video as a tensor with shape + # fhwc: (num_frames, height, width, num_channels). + video_fhwc = sample.image.permute(0, 2, 3, 1) + selected_frames = torch.linspace( + 0, video_fhwc.shape[0] - 1, self.args.num_frames).long() + video_frame_fhwc = video_fhwc[selected_frames] + imgs = [] + for video_frame_hwc in video_frame_fhwc: + imgs += get_visual_transform( + video_frame_hwc, self.img_h, self.img_w, + self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn) + else: + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + + num_tiles = [len(imgs)] + + if isinstance(sample, MultiChoiceVQASample): + cur_prompt = format_multichoice_question(sample.context, sample.choices) + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + cur_answer = format_multichoice_answer(sample.correct_choice_idx) + elif isinstance(sample, VQASample): + if 'docvqa' in sample.__key__: + prompt_list = self.manual_prompts["VQASFT"]["docvqa"] + elif sample.__subflavors__.get("VQASFT"): + prompt_list = self.manual_prompts["VQASFT"]["raw"] + else: + prompt_list = ["{}"] + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + + cur_prompt = cur_prompt.format(sample.context) + + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + if isinstance(sample.answers, list): + answer_list = sample.answers + weight_list = np.array(sample.answer_weights).astype(np.float32) + weight_list = weight_list / np.sum(weight_list) + answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0] + cur_answer = answer_list[answer_idx] + else: + cur_answer = sample.answers + else: + raise NotImplementedError("Unsupported data type provided", sample) + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": cur_prompt}, + {"role": "assistant", "content": str(cur_answer)}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return 
ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def combined_ocr_encoder(self, sample, task_type): + """Encode OCR samples.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + + if task_type == "encode_pdf": + sample, cur_prompt, cur_answer = self.encode_pdf_prompt(sample) + elif task_type == "encode_ocr_ref": + sample, cur_prompt, cur_answer = self.encode_ocr_ref_prompt(sample) + elif task_type == "_encode_ocr": + sample, cur_prompt, cur_answer = self.encode_ocr_prompt(sample) + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": cur_prompt}, + {"role": "assistant", "content": str(cur_answer)}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + prompt_list = self.manual_prompts["DocPretraining"]["raw"] + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + # Make sure there is no extra IMAGE_TOKEN tag. 
+ sample.text = sample.text.replace(IMAGE_TOKEN, "") + + caption = sample.text.strip() + + split_by_line_flag = sample.__subflavors__.get("SplitByLine") + if split_by_line_flag: + caption_list = caption.split('\n') + caption = np.random.choice(caption_list) + cur_answer = caption + + return sample, cur_prompt, cur_answer + + def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + ref = sample.text + region = sample.words_boxes + + # Make sure there is no extra IMAGE_TOKEN tag + ref = ref.replace(IMAGE_TOKEN, "") + + if len(region) == 4: + region = f"({region[0]},{region[1]}),({region[2]},{region[3]})" + else: + region = f"({region[0]},{region[1]}),({region[2]},{region[3]}),({region[4]},{region[5]}),({region[6]},{region[7]})" + + # Randomly choose between two tasks + task_idx = np.random.randint(2) + if task_idx == 0: + # Referring Grounding + prompt_list = self.manual_prompts["DocPretraining"]["referring_grounding"] + prompt_content = ref + answer = region + else: + # Grounded OCR + prompt_list = self.manual_prompts["DocPretraining"]["grounded_ocr"] + prompt_content = region + answer = ref + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + cur_prompt = cur_prompt.format(prompt_content) + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + return sample, cur_prompt, answer + + def bbox_coord_to_label(self, text, bbox): + """Format bbox coordinates as text.""" + assert len(bbox) == 4 or len(bbox) == 8 + + # Make sure there is no extra IMAGE_TOKEN tag + text = text.replace(IMAGE_TOKEN, "") + + if len(bbox) == 4: + label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})" + else: + label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]}),({bbox[4]},{bbox[5]}),({bbox[6]},{bbox[7]})" + + return label_str + + def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + if isinstance(sample.words_boxes[0], int): + answer = self.bbox_coord_to_label(sample.text, sample.words_boxes) + elif isinstance(sample.words_boxes[0], list): + answer = "" + for i, bbox in enumerate(sample.words_boxes): + answer += self.bbox_coord_to_label(sample.words_text[i], bbox) + + prompt_list = self.manual_prompts["DocPretraining"]["ocr_multi"] + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + cur_answer = answer + + return sample, cur_prompt, cur_answer + + def batch(self, samples: List[Union[ImageTaskSample, ImageTaskSamplePacked]]) -> ImageTaskBatchPacked: + # Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image. + imgs = [img for s in samples for img in s.imgs] + if len(imgs) > 0: + imgs = torch.stack(imgs) + else: + imgs = torch.tensor([[0]], dtype=torch.float32) + + # If the user hasn't defined a target dataloader sequence length, then use the max along the sample lengths. + max_seq_len = self.dataloader_seq_length + if not max_seq_len: + max_seq_len = max(len(s.tokens) for s in samples) + + tokens = np.full((len(samples), max_seq_len), self.tokenizer.pad, dtype=np.int64) + # +1 to accommodate shift to left by one later. + labels = np.full((len(samples), max_seq_len + 1), self.tokenizer.pad, dtype=np.int64) + + for i, s in enumerate(samples): + # If the sample/target length exceeds the target sequence length, then truncate. 
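# --- Illustrative sketch (not part of the patch) -----------------------------
# The per-sample copy below only handles sizes: tokens are cut to max_seq_len,
# while labels keep max_seq_len + 1 entries so that the later shift-left-by-one
# still yields max_seq_len targets. Toy numbers (pad id 0, max_seq_len 8, a
# 10-token sample) make the truncation visible; they are assumed, not taken
# from the code.
import numpy as np

toy_max_seq_len, toy_pad = 8, 0
toy_sample = np.arange(1, 11)                              # 10 toy token ids
toy_tokens = np.full(toy_max_seq_len, toy_pad, dtype=np.int64)
toy_labels = np.full(toy_max_seq_len + 1, toy_pad, dtype=np.int64)
toy_tokens[: min(toy_max_seq_len, len(toy_sample))] = toy_sample[:toy_max_seq_len]
toy_labels[: min(toy_max_seq_len + 1, len(toy_sample))] = toy_sample[: toy_max_seq_len + 1]
# toy_tokens -> [1 2 3 4 5 6 7 8], toy_labels -> [1 2 3 4 5 6 7 8 9]
# ------------------------------------------------------------------------------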
+ text_len = min(max_seq_len, len(s.tokens)) + target_len = min(max_seq_len+1, len(s.labels)) + + tokens[i, :text_len] = s.tokens[:text_len] + labels[i, :target_len] = s.labels[:target_len] + + num_tiles = torch.tensor([n for s in samples for n in s.num_tiles], dtype=torch.int32) + if len(num_tiles) == 0: + num_tiles = torch.tensor([[0]], dtype=torch.int32) + + # Cumulative sample lengths are needed for packing, otherwise use dummy values. + cu_lengths = torch.tensor([[0]], dtype=torch.int32) + max_lengths = torch.tensor([[0]], dtype=torch.int32) + + if self.is_packing_enabled: + cu_lengths = torch.stack([s.cu_lengths for s in samples]) + max_lengths = torch.tensor([s.max_length for s in samples], dtype=torch.int32) + + return ImageTaskBatchPacked( + __key__=[s.__key__ for s in samples], + __restore_key__=[s.__restore_key__ for s in samples], + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + tokens=tokens, + labels=labels, + imgs=imgs, + num_tiles=num_tiles, + cu_lengths=cu_lengths, + max_lengths=max_lengths, + ) + + def encode_batch(self, batch: ImageTaskBatchPacked) -> dict: + raw = dataclasses.asdict(batch) + del raw["__subflavors__"] + return raw + + def select_samples_to_pack(self, samples: List[ImageTaskSample]) -> List[List[ImageTaskSample]]: + """Selects which samples will be packed together. + + NOTE: Energon dataloader calls this method internally if packing is used. + Please see https://nvidia.github.io/Megatron-Energon/packing.html + """ + lengths = [sample.total_len for sample in samples] + + packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length) + + return packed_samples + + @stateless + def pack_selected_samples(self, samples: List[ImageTaskSample]) -> List[ImageTaskSamplePacked]: + """ + Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked. + + NOTE: Energon dataloader calls this method internally if packing is used. + Please see https://nvidia.github.io/Megatron-Energon/packing.html + + Args: + samples: List of ImageTaskSample instances to pack into one sample. + + Returns: + ImageTaskSamplePacked instance. + """ + packing_seq_len = self.packing_seq_length + + packed_tokens = [] + packed_labels = [] + packed_imgs = [] + + current_length = 0 + max_length = 0 + cu_lengths = [0] + + # Process each sample and build lists that we will concatenate to create the packed sample. + for _, sample in enumerate(samples): + sample_len = sample.total_len + + if sample_len > max_length: + max_length = sample_len + + # If adding this sample exceeds the max length, stop. + # This should not happen. The select_samples_to_pack method should have already ensured that the samples fit. + if current_length + sample_len > packing_seq_len: + raise ValueError(f"Packed sample exceeds the maximum sequence length of {packing_seq_len}: {samples}") + + # Add the sample's tokens and labels + packed_tokens.append(sample.tokens) + packed_labels.append(sample.labels) + + # Add the images + packed_imgs += sample.imgs + + current_length += sample_len + cu_lengths.append(current_length) + + # Concatenate packed tokens and labels. 
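# --- Illustrative sketch (not part of the patch) -----------------------------
# After the loop above, `cu_lengths` holds cumulative offsets that delimit each
# packed sample inside the concatenated token tensor. With assumed per-sample
# lengths of 5, 3 and 4:
toy_lengths = [5, 3, 4]
toy_cu_lengths = [0]
for length in toy_lengths:
    toy_cu_lengths.append(toy_cu_lengths[-1] + length)
# toy_cu_lengths == [0, 5, 8, 12]; packed sample i then spans
# packed_tokens[toy_cu_lengths[i]:toy_cu_lengths[i + 1]].
# ------------------------------------------------------------------------------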
+ packed_tokens = torch.cat(packed_tokens, dim=0) + packed_labels = torch.cat(packed_labels, dim=0) + + return ImageTaskSamplePacked( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + tokens=packed_tokens, + labels=packed_labels, + imgs=packed_imgs, + cu_lengths=torch.tensor(cu_lengths, dtype=torch.int32), + max_length=max_length, + num_tiles=[n for s in samples for n in s.num_tiles], + ) + + +def print_error_handler(exc: Exception, key: Optional[str]): + print( + f"The following exception occurred in the dataloader for sample {key} and is skipped", + file=sys.stderr, + ) + traceback.print_exc() + + +def format_multichoice_question(question, multichoice_options): + """Format multi-choice question.""" + options_text = ["{}. {}\n".format(chr(ord('A') + i), option) for i, option in + zip(range(len(multichoice_options)), multichoice_options)] + options_text = "".join(options_text) + + options_text = f"{options_text}Answer with the option's letter from the given choices directly." + + return "{}\n{}".format(question, options_text) + + +def format_multichoice_answer(idx): + """Format multi-choice answer.""" + return chr(ord('A') + idx) diff --git a/examples/multimodal/energon_util.py b/examples/multimodal/energon_util.py new file mode 100644 index 0000000000000000000000000000000000000000..36135ff058c2056deefdb79f8d311ea990780257 --- /dev/null +++ b/examples/multimodal/energon_util.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch +import warnings +from dataclasses import dataclass +from typing import Any, List + +from megatron.energon import Sample +from megatron.energon.epathlib.epath import EPath +from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory + + +@dataclass +class SampleListSample(Sample): + """Sample type for a list of samples of any type which needs to be packed together. + + This is useful for datasets which are packed offline. 
+ """ + + #: The images of the sequence + samples: List[Any] + + +class SampleListWebdataset(DefaultDecoderWebdatasetFactory[SampleListSample]): + __sample_type__ = SampleListSample + + def __init__(self, path: EPath, **kwargs): + warnings.warn( + f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" + f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" + f"# remove top-level __module__ and __class__\n" + f"sample_type:\n" + f" __module__: megatron.energon\n" + f" __class__: {self.__sample_type__.__name__}\n" + f"# Keep the remaining content", + DeprecationWarning, + ) + super().__init__(path, **kwargs) + + +@dataclass +class OfflineTargetAspectRatioSample(Sample): + """Sample type for image + text samples with target aspect ratio computed offline.""" + + #: The images of the sequence + images: List[torch.Tensor] + #: The texts of the sequence + texts: List[str] + target_aspect_ratio: List[List] diff --git a/examples/multimodal/evaluation/evaluate_infovqa.py b/examples/multimodal/evaluation/evaluate_infovqa.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee8d3ed464c89a230b59bcef900d48e101d379e --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_infovqa.py @@ -0,0 +1,48 @@ +import argparse +import json + +from evaluate_vqav2 import compute_vqa_accuracy +from evaluate_mmmu import get_input_output_paths + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="InfoVQA") + + results = [] + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + results.append( + { + "question_id": res["sample_id"], + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + ) + + # Make order deterministic. + # results = sorted(results, key=lambda d: d["question_id"]) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def infovqa_eval(input_path): + """Run InfoVQA evaluation.""" + result_file_path = merge_input_files(input_path) + return compute_vqa_accuracy(result_file_path, task="InfoVQA") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = infovqa_eval(args.input_path) + + print(f"===== InfoVQA Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluate_spdocvqa.py b/examples/multimodal/evaluation/evaluate_spdocvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a4fd071ae7a12bb0b49dbea2424bcc622877be --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_spdocvqa.py @@ -0,0 +1,48 @@ +import argparse +import json + +from evaluate_vqav2 import compute_vqa_accuracy +from evaluate_mmmu import get_input_output_paths + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="SPDocVQA") + + results = [] + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + results.append( + { + "question_id": res["sample_id"], + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + ) + + # Make order deterministic. 
+ # results = sorted(results, key=lambda d: d["question_id"]) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def spdocvqa_eval(input_path): + """Run SPDocVQA evaluation.""" + result_file_path = merge_input_files(input_path) + return compute_vqa_accuracy(result_file_path, task="SPDocVQA") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = spdocvqa_eval(args.input_path) + + print(f"===== SPDocVQA Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluate_vqav2.py b/examples/multimodal/evaluation/evaluate_vqav2.py index 7807d80723f5aa67c7fcadd695e78643fd52cb6d..42ec6e675c88f3fcad42dd38a0f7b5016cb16b5b 100644 --- a/examples/multimodal/evaluation/evaluate_vqav2.py +++ b/examples/multimodal/evaluation/evaluate_vqav2.py @@ -1,109 +1,161 @@ -import argparse -import json - -from evaluate_mmmu import get_input_output_paths -from open_flamingo.eval.vqa_metric import VQAEval - - -def merge_input_files(input_path): - """Merge input files to a format compatible with the evaluator.""" - input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2") - - results = dict() - - for input_file_path in input_file_paths: - with open(input_file_path, "r") as input_file: - for line in input_file: - res = json.loads(line) - sample_id = res["sample_id"] - - # Skip possible duplicates. - if sample_id in results: - continue - - res["question_id"] = sample_id - results[sample_id] = res - - results = list(results.values()) - - with open(output_file_path, "w") as output_file: - json.dump(results, output_file) - - return output_file_path - - -def is_number(n: str): - """Check if input is a number.""" - try: - float(n) - return True - except ValueError: - return False - - -def compute_vqa_accuracy(result_file, task): - """Compute VQA accuracy.""" - merged_results = json.load(open(result_file)) - - vqa = VQAEval(vqa=None, vqaRes=None) - all_acc = [] - for res in merged_results: - pred = res["answer"] - pred = vqa.processPunctuation(pred) - pred = vqa.processDigitArticle(pred) - - gt = res["gt_answer"] - gt = [vqa.processPunctuation(ans) for ans in gt] - gt = [vqa.processDigitArticle(ans) for ans in gt] - - # ChartQA uses relaxed accuracy: - # "We consider an answer to be correct if it is within 5% of the gold answer. - # For non-numeric answers, we still need an exact match to consider an answer to be correct." - if task == "ChartQA": - acc = 0.0 - assert len(gt) == 1, "expected exactly one groundtruth answer." 
- gt = gt[0] - - pred = pred.rstrip("%") - gt = gt.rstrip("%") - - if is_number(pred) and is_number(gt): - pred = float(pred) - gt = float(gt) - if pred >= (gt * 0.95) and pred <= (gt * 1.05): - acc = 1.0 - elif pred == gt: - acc = 1.0 - - all_acc.append(acc) - elif task in ("VQAv2", "TextVQA"): - num_match = sum([pred == ans for ans in gt]) - acc = min(1.0, num_match / 3.0) - all_acc.append(acc) - elif task == "AI2D": - assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}" - acc = pred == gt[0] - all_acc.append(acc) - else: - raise NotImplementedError(f"unknown task {task}") - - acc_avg = sum(all_acc) / len(all_acc) * 100 - - return acc_avg - - -def vqav2_eval(input_path): - """Run VQAv2 evaluation.""" - result_file = merge_input_files(input_path) - avg_acc = compute_vqa_accuracy(result_file, task="VQAv2") - return avg_acc - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--input-path', type=str, help="Path to input file(s)") - args = parser.parse_args() - - avg_acc = vqav2_eval(args.input_path) - - print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====") +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from open_flamingo.eval.vqa_metric import VQAEval + +# ANLS score calculation based on https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/dist.py#L1 +# and https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/score.py#L6 +# MIT License. Copyright (c) 2022 Shunsuke KITADA +def levenshtein_distance(s1: str, s2: str) -> int: + + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = list(range(len(s1) + 1)) + for i2, c2 in enumerate(s2): + dists = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + dists.append(distances[i1]) + else: + dists.append(1 + min((distances[i1], distances[i1 + 1], dists[-1]))) + distances = dists + + return distances[-1] + + +def normalized_levenshtein_distance(s1: str, s2: str) -> float: + dist = levenshtein_distance(s1, s2) + length = max(len(s1.upper()), len(s2.upper())) + return 0.0 if length == 0 else dist / length + +def similarity_function(prediction: str, gold_label: str, threshold: float) -> float: + nl_score = normalized_levenshtein_distance(prediction, gold_label) + return 1 - nl_score if nl_score < threshold else 0.0 + +def anls_score( + prediction: str, gold_labels: List[str], threshold: float = 0.5 +) -> float: + + # not case sensitive, but space sensitive + y_pred = " ".join(prediction.strip().lower().split()) + + anls_scores: List[float] = [] + for gold_label in gold_labels: + + # not case sensitive, but space sensitive + y_true = " ".join(gold_label.strip().lower().split()) + + anls_score = similarity_function(y_pred, y_true, threshold) + anls_scores.append(anls_score) + + score = max(anls_scores) + + return score + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Skip possible duplicates. 
+ if sample_id in results: + continue + + res["question_id"] = sample_id + results[sample_id] = res + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def is_number(n: str): + """Check if input is a number.""" + try: + float(n) + return True + except ValueError: + return False + + +def compute_vqa_accuracy(result_file, task): + """Compute VQA accuracy.""" + merged_results = json.load(open(result_file)) + + vqa = VQAEval(vqa=None, vqaRes=None) + all_acc = [] + for res in merged_results: + pred = res["answer"] + pred = vqa.processPunctuation(pred) + pred = vqa.processDigitArticle(pred) + + gt = res["gt_answer"] + gt = [vqa.processPunctuation(ans) for ans in gt] + gt = [vqa.processDigitArticle(ans) for ans in gt] + + # ChartQA uses relaxed accuracy: + # "We consider an answer to be correct if it is within 5% of the gold answer. + # For non-numeric answers, we still need an exact match to consider an answer to be correct." + if task == "ChartQA": + acc = 0.0 + assert len(gt) == 1, "expected exactly one groundtruth answer." + gt = gt[0] + + pred = pred.rstrip("%") + gt = gt.rstrip("%") + + if is_number(pred) and is_number(gt): + pred = float(pred) + gt = float(gt) + if pred >= (gt * 0.95) and pred <= (gt * 1.05): + acc = 1.0 + elif pred == gt: + acc = 1.0 + + all_acc.append(acc) + elif task in ("VQAv2", "TextVQA"): + num_match = sum([pred == ans for ans in gt]) + acc = min(1.0, num_match / 3.0) + all_acc.append(acc) + elif task in ("SPDocVQA", "InfoVQA"): + acc = anls_score(prediction=pred, gold_labels=gt, threshold=0.5) + all_acc.append(acc) + elif task == "AI2D": + assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}" + acc = pred == gt[0] + all_acc.append(acc) + else: + raise NotImplementedError(f"unknown task {task}") + + acc_avg = sum(all_acc) / len(all_acc) * 100 + + return acc_avg + + +def vqav2_eval(input_path): + """Run VQAv2 evaluation.""" + result_file = merge_input_files(input_path) + avg_acc = compute_vqa_accuracy(result_file, task="VQAv2") + return avg_acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = vqav2_eval(args.input_path) + + print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluation_datasets.py b/examples/multimodal/evaluation/evaluation_datasets.py index 50a50d56871bddd9de59c3b1444186c749892db8..a2d334689c82ea9b8802cb61ad24de52b7406575 100644 --- a/examples/multimodal/evaluation/evaluation_datasets.py +++ b/examples/multimodal/evaluation/evaluation_datasets.py @@ -1,920 +1,948 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
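# --- Illustrative sketch (not part of the patch) -----------------------------
# Quick sanity check of the ANLS scoring added to evaluate_vqav2.py above: the
# score is 1 minus the normalized Levenshtein distance when that distance is
# below the threshold (0.5 by default), and 0.0 otherwise. The strings and the
# helper name below are made up for illustration.
def toy_edit_distance(s1: str, s2: str) -> int:
    # Plain dynamic-programming Levenshtein distance, mirroring the helper above.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        dists = [i2 + 1]
        for i1, c1 in enumerate(s1):
            dists.append(distances[i1] if c1 == c2 else 1 + min(distances[i1], distances[i1 + 1], dists[-1]))
        distances = dists
    return distances[-1]

toy_pred, toy_gold = "invoice total", "invoice totals"
toy_nl_dist = toy_edit_distance(toy_pred, toy_gold) / max(len(toy_pred), len(toy_gold))
toy_score = 1 - toy_nl_dist if toy_nl_dist < 0.5 else 0.0   # 1 - 1/14 ≈ 0.93
# ------------------------------------------------------------------------------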
-"""Evaluation datasets.""" -import glob -import itertools -import json -import os -import re -from collections import defaultdict - -import numpy as np -import torch -from image_processing import get_visual_transform -from PIL import Image - -from megatron.training import print_rank_0 - - -def _get_partition_bounds( - total_num_samples, num_samples_per_partition, num_partitions, partition_id -): - if num_samples_per_partition == 0: - samples_per_partition = [ - int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) - ] - return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] - return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) - - -class VQADataset(torch.utils.data.Dataset): - """VQA evaluation dataset.""" - - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ): - samples = json.load(open(gt_path, encoding='utf-8')) - if "data" in samples: - samples = samples["data"] - - # Optionally, process only a subset of the input files. - if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(samples), num_samples_per_partition, num_partitions, partition_id - ) - samples = samples[lb:ub] - - self._keys = keys - self._samples = samples - self._input_image_path = input_image_path - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._samples) - - def __getitem__(self, idx): - sample = self._samples[idx] - - img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]]) - if not os.path.exists(img_file): - img_file += ".jpg" - - if not os.path.exists(img_file): - img_file = img_file.replace('.jpg', '.png') - - img = Image.open(img_file) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - sample_id = idx - if "sample_id" in self._keys: - sample_id = sample[self._keys["sample_id"]] - - metadata = "" # Not used. - - return ( - torch.stack(imgs), - tile_count, - sample_id, - sample[self._keys["question"]], - sample[self._keys["answer"]], - metadata, - ) - - -class CaptioningDataset(torch.utils.data.Dataset): - """Captioning evaluation dataset.""" - - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ): - image_files = sorted(glob.glob(input_image_path + "/*")) - - # Optionally, process only a subset of the input files. 
- if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(image_files), num_samples_per_partition, num_partitions, partition_id - ) - image_files = image_files[lb:ub] - - gts = json.load(open(gt_path)) - answers = defaultdict(list) - for gt in gts["annotations"]: - answers[gt["image_id"]].append(gt['caption']) - - self._image_files = image_files - self._answers = answers - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._image_files) - - def __getitem__(self, idx): - img_file = self._image_files[idx] - image_id = int(img_file.split("_")[-1].split(".")[0]) - - img = Image.open(img_file) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - question = "" # Fixed for all samples. - metadata = "" # Not used. - - return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata - - -class MMMUDataset(torch.utils.data.Dataset): - """MMMU evaluation dataset.""" - - def __init__( - self, - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - prompt_style, - vision_model_type, - ): - import datasets - from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml - - # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. - all_mmmu_datasets = [] - - hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] - assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." - - for subject in CAT_SHORT2LONG.values(): - # Use a local copy of the dataset if exists (can be faster) or the HF one. - if os.path.exists(input_image_path): - subject_dataset = datasets.load_dataset( - os.path.join(input_image_path, subject), - split=datasets.Split.VALIDATION, - cache_dir=hf_datasets_cache, - verification_mode="no_checks", - ) - else: - subject_dataset = datasets.load_dataset( - "MMMU/MMMU", - subject, - split=datasets.Split.VALIDATION, - cache_dir=hf_datasets_cache, - ) - - all_mmmu_datasets.append(subject_dataset) - - dataset = datasets.concatenate_datasets(all_mmmu_datasets) - - dataset = [s for s in dataset if s['id'].startswith("val")] - - # Optionally, process only a subset of the input files. - if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(dataset), num_samples_per_partition, num_partitions, partition_id - ) - dataset = dataset[lb:ub] - - # Using the LLaVA config from the MMMU repo. - config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") - for k, v in config.items(): - if isinstance(v, list): - assert len(v) == 1, "only one value supported." 
- config[k] = v[0] - - self._config = config - - self._dataset = dataset - - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._prompt_style = prompt_style - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, idx): - from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample - - sample = self._dataset[idx] - - # Use the single image approach from the MMMU repo. - if self._prompt_style == "single_image": - sample = process_single_sample(sample) - sample = construct_prompt(sample, self._config) - - img = sample["image"] - sample_imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - sample_num_tiles = [len(sample_imgs)] - - prompt = sample["final_input_prompt"] - for i in range(8): - prompt = prompt.replace(f"", "") - sample["final_input_prompt"] = f"\n{prompt}" - elif self._prompt_style == "vlmevalkit": - sample = construct_prompt(sample, self._config) - - if sample["question_type"] == "multiple-choice": - question = sample["question"] - - options = "" - for k, v in sample["index2ans"].items(): - options += f"{k}. {v}\n" - - final_prompt = f"{question}\n" - if "hint" in sample: - final_prompt += f"Hint: {sample['hint']}\n" - - if "task_instructions" in sample: - final_prompt += f"Task instructions: {sample['task_instructions']}\n" - - final_prompt += options - final_prompt += "Answer with the option's letter from the given choices directly." - - sample["final_input_prompt"] = final_prompt.rstrip() - else: - question = sample["question"] - final_prompt = f"{question}\n" - final_prompt += "Answer the question directly." - sample["final_input_prompt"] = final_prompt.rstrip() - - sample_imgs = [] - sample_num_tiles = [] - - img_indices = sorted(list(set(re.findall(r"" - - img = sample[img_key] - assert img is not None, f"{img_str} is in prompt but not in sample images" - - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - adjusted_max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) # List of tiles. - - sample_imgs.extend(imgs) - sample_num_tiles.append(len(imgs)) - - sample["final_input_prompt"] = " ".join([f'' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"] - elif self._prompt_style == "multi_image": - sample = construct_prompt(sample, self._config) - - sample_imgs = [] - sample_num_tiles = [] - - img_indices = re.findall(r"" - - img = sample[img_key] - assert img is not None, f"{img_str} is in prompt but not in sample images" - - # Note: Only replace the current image tag. - sample["final_input_prompt"] = sample["final_input_prompt"].replace( - img_str, "", 1 - ) - - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - adjusted_max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) # List of tiles. - - sample_imgs.extend(imgs) - sample_num_tiles.append(len(imgs)) - - # Sanity check. - for i in range(1, 8): - assert ( - f"" not in sample["final_input_prompt"] - ), "prompt contains unhandled image tags" - else: - raise ValueError(f"unknown prompt style {self._prompt_style}") - - # MMMU specific metadata. 
- metadata = {"question_type": sample["question_type"]} - if sample["question_type"] == "multiple-choice": - metadata["index2ans"] = sample["index2ans"] - metadata["all_choices"] = sample["all_choices"] - - prompt = sample['final_input_prompt'] - - tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) - - return ( - torch.stack(sample_imgs), - tile_count, - sample["id"], - prompt, - sample["answer"], - metadata, - ) - - -class VideoMMMEDataset(torch.utils.data.Dataset): - "Video MME evaluation dataset." - - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_frames, - vision_model_type, - ): - ground_truth_original = json.load(open(gt_path)) - ground_truth = [] - for gt in ground_truth_original: - video_path = gt["url"] - video_path = video_path.replace("https://www.youtube.com/watch?v=", "") - video_path = video_path.replace("https://m.youtube.com/watch?v=", "") - video_path = os.path.join(input_image_path, video_path + ".mp4") - if not os.path.exists(video_path): - continue - gt["video_path"] = video_path - ground_truth.append(gt) - - ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) - print_rank_0(f"Found {len(ground_truth)} videos to process.") - - if num_partitions > 0: - start_idx, end_idx = _get_partition_bounds( - len(ground_truth), num_samples_per_partition, num_partitions, partition_id - ) - ground_truth = ground_truth[start_idx:end_idx] - - self._ground_truth = ground_truth - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._num_frames = num_frames - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._ground_truth) - - def __getitem__(self, idx): - from torchvision.io import read_video - - gt = self._ground_truth[idx] - - video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') - video = video.numpy() - selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() - video_frames = video[selected_frames] - if self._num_frames == 1: - video_frames = video_frames[None] - - imgs = list( - itertools.chain.from_iterable( - get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - for img in video_frames - ) - ) - - for question in gt["questions"]: - # Very hacky, but we essentially re-create gt holding only the - # question of interest. This is the make this generation script - # compatible with the Video MME evaluation script. 
- question_dict = { - "video_id": gt["video_id"], - "duration_category": gt["duration_category"], - "video_category": gt["video_category"], - "video_subcategory": gt["video_subcategory"], - "url": gt["url"], - "questions": [question], - } - - num_tiles = torch.tensor([len(imgs)], dtype=torch.int) - - answer = "" - metadata = "" - - return ( - torch.stack(imgs), - num_tiles, - question["question_id"], - question_dict, - answer, - metadata, - ) - - -class OCRBenchDataset(torch.utils.data.Dataset): - """OCRBench evaluation dataset.""" - - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ): - gt = json.load(open(gt_path, encoding='utf-8')) - - if num_partitions > 0: - start_idx, end_idx = _get_partition_bounds( - len(gt), num_samples_per_partition, num_partitions, partition_id - ) - gt = gt[start_idx:end_idx] - - self._input_image_path = input_image_path - self._gt = gt - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._gt) - - def __getitem__(self, idx): - img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path']) - - img = Image.open(img_path) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - metadata = { - "dataset_name": self._gt[idx]["dataset_name"], - "data_type": self._gt[idx]["type"], - } - - return ( - torch.stack(imgs), - tile_count, - idx, - self._gt[idx]["question"], - self._gt[idx]["answers"], - metadata, - ) - - -class MathVistaDataset(torch.utils.data.Dataset): - """MathVista evaluation dataset.""" - - def __init__( - self, - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ): - import datasets - - hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] - assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." - - if os.path.exists(input_image_path): - dataset = datasets.load_dataset( - input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks" - ) - else: - dataset = datasets.load_dataset( - "AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache - ) - - if num_partitions > 0: - start_idx, end_idx = _get_partition_bounds( - len(dataset), num_samples_per_partition, num_partitions, partition_id - ) - dataset = dataset[start_idx:end_idx] - - self._dataset = dataset - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._dataset["pid"]) - - def __getitem__(self, idx): - # Already a PIL object. 
- img = self._dataset['decoded_image'][idx] - - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - question_id = self._dataset["pid"][idx] - question = self._dataset["question"][idx] - question_type = self._dataset["question_type"][idx] # free_form or multi_choice - query = self._dataset["query"][idx] - choices = self._dataset["choices"][idx] - answer = self._dataset["answer"][idx] - - if question_type == 'multi_choice': - start_chr = 'A' - choices_str = '' - index2ans = {} - all_choices = [] - for choice in choices: - all_choices.append(start_chr) - index2ans[start_chr] = choice - choices_str += f"{start_chr}. {choice}\n" - start_chr = chr(ord(start_chr) + 1) - - question = question + '\n' + choices_str - question = question + "Answer with the option's letter from the given choices directly." - answer = chr(ord('A') + choices.index(answer)) - else: - question = query.replace("Hint: ", "") - index2ans = {} - all_choices = [] - - metadata = { - "question_type": question_type, - "index2ans": index2ans, - "all_choices": all_choices, - } - - return torch.stack(imgs), tile_count, question_id, question, answer, metadata - - -class AI2DDataset(torch.utils.data.Dataset): - """AI2D evaluation dataset.""" - - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - no_mask, - vision_model_type, - ): - with open(gt_path, 'r') as f: - jsonl = list(f) - - gt = [json.loads(json_str) for json_str in jsonl] - - if num_partitions > 0: - start_idx, end_idx = _get_partition_bounds( - len(gt), num_samples_per_partition, num_partitions, partition_id - ) - gt = gt[start_idx:end_idx] - - self._gt = gt - self._input_image_path = input_image_path - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._no_mask = no_mask - self._vision_model_type = vision_model_type - - def __len__(self): - return len(self._gt) - - def __getitem__(self, idx): - img_path = os.path.join(self._input_image_path, self._gt[idx]['image']) - if self._no_mask: - img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES") - - img = Image.open(img_path) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - vision_model_type=self._vision_model_type, - ) - - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - metadata = "" # Not used. 
- - return ( - torch.stack(imgs), - tile_count, - self._gt[idx]["question_id"], - self._gt[idx]["question"], - self._gt[idx]["answer"], - metadata, - ) - - -def get_evaluation_dataset( - task, - input_image_path, - gt_path, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_samples_per_partition, - num_partitions, - partition_id, - num_frames, - vision_model_type, -): - """Get an evaluation dataset.""" - if task == "TextVQA": - keys = { - "image_id": "image_id", - "sample_id": "question_id", - "question": "question", - "answer": "answers", - } - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == "VQAv2": - keys = { - "image_id": "image", - "sample_id": "question_id", - "question": "question", - "answer": "answer", - } - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == "ChartQA": - keys = {"image_id": "imgname", "question": "query", "answer": "label"} - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == "captioning": - dataset = CaptioningDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == 'MMMU': - # Note: - # - prompt_style="single_image" uses only one image like in the MMMU repo example. - # - prompt_style="multi_image" uses multiple input images. - # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499 - dataset = MMMUDataset( - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - prompt_style="single_image", - vision_model_type=vision_model_type, - ) - elif task == "VideoMME": - dataset = VideoMMMEDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_frames, - vision_model_type, - ) - elif task == "OCRBench": - dataset = OCRBenchDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == "MathVista": - dataset = MathVistaDataset( - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - vision_model_type, - ) - elif task == "AI2D": - dataset = AI2DDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - no_mask=False, - vision_model_type=vision_model_type, - ) - else: - raise NotImplementedError(f"unsupported task {task}") - - return dataset +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+"""Evaluation datasets.""" +import glob +import itertools +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch +from image_processing import get_visual_transform +from PIL import Image + +from megatron.training import print_rank_0 + + +def _get_partition_bounds( + total_num_samples, num_samples_per_partition, num_partitions, partition_id +): + if num_samples_per_partition == 0: + samples_per_partition = [ + int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) + ] + return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] + return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) + + +class VQADataset(torch.utils.data.Dataset): + """VQA evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + samples = json.load(open(gt_path, encoding='utf-8')) + if "data" in samples: + samples = samples["data"] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(samples), num_samples_per_partition, num_partitions, partition_id + ) + samples = samples[lb:ub] + + self._keys = keys + self._samples = samples + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._samples) + + def __getitem__(self, idx): + sample = self._samples[idx] + + img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]]) + if not os.path.exists(img_file): + img_file += ".jpg" + + if not os.path.exists(img_file): + img_file = img_file.replace('.jpg', '.png') + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + sample_id = idx + if "sample_id" in self._keys: + sample_id = sample[self._keys["sample_id"]] + + metadata = "" # Not used. + + return ( + torch.stack(imgs), + tile_count, + sample_id, + sample[self._keys["question"]], + sample[self._keys["answer"]], + metadata, + ) + + +class CaptioningDataset(torch.utils.data.Dataset): + """Captioning evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + image_files = sorted(glob.glob(input_image_path + "/*")) + + # Optionally, process only a subset of the input files. 
+ if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(image_files), num_samples_per_partition, num_partitions, partition_id + ) + image_files = image_files[lb:ub] + + gts = json.load(open(gt_path)) + answers = defaultdict(list) + for gt in gts["annotations"]: + answers[gt["image_id"]].append(gt['caption']) + + self._image_files = image_files + self._answers = answers + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._image_files) + + def __getitem__(self, idx): + img_file = self._image_files[idx] + image_id = int(img_file.split("_")[-1].split(".")[0]) + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question = "" # Fixed for all samples. + metadata = "" # Not used. + + return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata + + +class MMMUDataset(torch.utils.data.Dataset): + """MMMU evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + prompt_style, + vision_model_type, + ): + import datasets + from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml + + # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. + all_mmmu_datasets = [] + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + for subject in CAT_SHORT2LONG.values(): + # Use a local copy of the dataset if exists (can be faster) or the HF one. + if os.path.exists(input_image_path): + subject_dataset = datasets.load_dataset( + os.path.join(input_image_path, subject), + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + verification_mode="no_checks", + ) + else: + subject_dataset = datasets.load_dataset( + "MMMU/MMMU", + subject, + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + ) + + all_mmmu_datasets.append(subject_dataset) + + dataset = datasets.concatenate_datasets(all_mmmu_datasets) + + dataset = [s for s in dataset if s['id'].startswith("val")] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[lb:ub] + + # Using the LLaVA config from the MMMU repo. + config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") + for k, v in config.items(): + if isinstance(v, list): + assert len(v) == 1, "only one value supported." 
+                config[k] = v[0]
+
+        self._config = config
+
+        self._dataset = dataset
+
+        self._img_h = img_h
+        self._img_w = img_w
+        self._use_tiling = use_tiling
+        self._max_num_tiles = max_num_tiles
+        self._use_thumbnail = use_thumbnail
+        self._prompt_style = prompt_style
+        self._vision_model_type = vision_model_type
+
+    def __len__(self):
+        return len(self._dataset)
+
+    def __getitem__(self, idx):
+        from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample
+
+        sample = self._dataset[idx]
+
+        # Use the single image approach from the MMMU repo.
+        if self._prompt_style == "single_image":
+            sample = process_single_sample(sample)
+            sample = construct_prompt(sample, self._config)
+
+            img = sample["image"]
+            sample_imgs = get_visual_transform(
+                img,
+                self._img_h,
+                self._img_w,
+                self._use_tiling,
+                self._max_num_tiles,
+                self._use_thumbnail,
+                augment=False,
+                vision_model_type=self._vision_model_type,
+            )
+            sample_num_tiles = [len(sample_imgs)]
+
+            prompt = sample["final_input_prompt"]
+            for i in range(8):
+                prompt = prompt.replace(f"<image {i+1}>", "")
+            sample["final_input_prompt"] = f"<image>\n{prompt}"
+        elif self._prompt_style == "vlmevalkit":
+            sample = construct_prompt(sample, self._config)
+
+            if sample["question_type"] == "multiple-choice":
+                question = sample["question"]
+
+                options = ""
+                for k, v in sample["index2ans"].items():
+                    options += f"{k}. {v}\n"
+
+                final_prompt = f"{question}\n"
+                if "hint" in sample:
+                    final_prompt += f"Hint: {sample['hint']}\n"
+
+                if "task_instructions" in sample:
+                    final_prompt += f"Task instructions: {sample['task_instructions']}\n"
+
+                final_prompt += options
+                final_prompt += "Answer with the option's letter from the given choices directly."
+
+                sample["final_input_prompt"] = final_prompt.rstrip()
+            else:
+                question = sample["question"]
+                final_prompt = f"{question}\n"
+                final_prompt += "Answer the question directly."
+                sample["final_input_prompt"] = final_prompt.rstrip()
+
+            sample_imgs = []
+            sample_num_tiles = []
+
+            img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"]))))
+            adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
+
+            for img_idx in img_indices:
+                img_key = f"image_{img_idx}"
+                img_str = f"<image {img_idx}>"
+
+                img = sample[img_key]
+                assert img is not None, f"{img_str} is in prompt but not in sample images"
+
+                imgs = get_visual_transform(
+                    img,
+                    self._img_h,
+                    self._img_w,
+                    self._use_tiling,
+                    adjusted_max_num_tiles,
+                    self._use_thumbnail,
+                    augment=False,
+                    vision_model_type=self._vision_model_type,
+                ) # List of tiles.
+
+                sample_imgs.extend(imgs)
+                sample_num_tiles.append(len(imgs))
+
+            sample["final_input_prompt"] = " ".join([f'<image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"]
+        elif self._prompt_style == "multi_image":
+            sample = construct_prompt(sample, self._config)
+
+            sample_imgs = []
+            sample_num_tiles = []
+
+            img_indices = re.findall(r"<image (\d+)", sample["final_input_prompt"])
+            adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
+
+            for img_idx in img_indices:
+                img_key = f"image_{img_idx}"
+                img_str = f"<image {img_idx}>"
+
+                img = sample[img_key]
+                assert img is not None, f"{img_str} is in prompt but not in sample images"
+
+                # Note: Only replace the current image tag.
+                sample["final_input_prompt"] = sample["final_input_prompt"].replace(
+                    img_str, "<image>", 1
+                )
+
+                imgs = get_visual_transform(
+                    img,
+                    self._img_h,
+                    self._img_w,
+                    self._use_tiling,
+                    adjusted_max_num_tiles,
+                    self._use_thumbnail,
+                    augment=False,
+                    vision_model_type=self._vision_model_type,
+                ) # List of tiles.
+
+                sample_imgs.extend(imgs)
+                sample_num_tiles.append(len(imgs))
+
+            # Sanity check.
+            for i in range(1, 8):
+                assert (
+                    f"<image {i}>" not in sample["final_input_prompt"]
+                ), "prompt contains unhandled image tags"
+        else:
+            raise ValueError(f"unknown prompt style {self._prompt_style}")
+
+        # MMMU specific metadata.
+ metadata = {"question_type": sample["question_type"]} + if sample["question_type"] == "multiple-choice": + metadata["index2ans"] = sample["index2ans"] + metadata["all_choices"] = sample["all_choices"] + + prompt = sample['final_input_prompt'] + + tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) + + return ( + torch.stack(sample_imgs), + tile_count, + sample["id"], + prompt, + sample["answer"], + metadata, + ) + + +class VideoMMEDataset(torch.utils.data.Dataset): + "Video MME evaluation dataset." + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + vision_model_type, + ): + ground_truth_original = json.load(open(gt_path)) + ground_truth = [] + for gt in ground_truth_original: + video_path = gt["url"] + video_path = video_path.replace("https://www.youtube.com/watch?v=", "") + video_path = video_path.replace("https://m.youtube.com/watch?v=", "") + video_path = os.path.join(input_image_path, video_path + ".mp4") + if not os.path.exists(video_path): + continue + gt["video_path"] = video_path + ground_truth.append(gt) + + ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) + print_rank_0(f"Found {len(ground_truth)} videos to process.") + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(ground_truth), num_samples_per_partition, num_partitions, partition_id + ) + ground_truth = ground_truth[start_idx:end_idx] + + self._ground_truth = ground_truth + self._img_h = img_h + self._img_w = img_w + self._use_tiling = False + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._num_frames = num_frames + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._ground_truth) + + def __getitem__(self, idx): + from torchvision.io import read_video + + gt = self._ground_truth[idx] + + video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') + video = video.numpy() + selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() + video_frames = video[selected_frames] + if self._num_frames == 1: + video_frames = video_frames[None] + + imgs = [] + for img in video_frames: + from torchvision.transforms import ToPILImage + to_pil = ToPILImage() + img = to_pil(img) + imgs += get_visual_transform( + img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles, + self._use_thumbnail, augment=False, vision_model_type=self._vision_model_type + ) + + for question in gt["questions"]: + # Very hacky, but we essentially re-create gt holding only the + # question of interest. This is the make this generation script + # compatible with the Video MME evaluation script. 
+ question_dict = { + "video_id": gt["video_id"], + "duration_category": gt["duration_category"], + "video_category": gt["video_category"], + "video_subcategory": gt["video_subcategory"], + "url": gt["url"], + "questions": [question], + } + + num_tiles = torch.tensor([len(imgs)], dtype=torch.int) + + answer = "" + metadata = "" + + return ( + torch.stack(imgs), + num_tiles, + question["question_id"], + question_dict, + answer, + metadata, + ) + + +class OCRBenchDataset(torch.utils.data.Dataset): + """OCRBench evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + gt = json.load(open(gt_path, encoding='utf-8')) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._input_image_path = input_image_path + self._gt = gt + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path']) + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = { + "dataset_name": self._gt[idx]["dataset_name"], + "data_type": self._gt[idx]["type"], + } + + return ( + torch.stack(imgs), + tile_count, + idx, + self._gt[idx]["question"], + self._gt[idx]["answers"], + metadata, + ) + + +class MathVistaDataset(torch.utils.data.Dataset): + """MathVista evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + import datasets + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + if os.path.exists(input_image_path): + dataset = datasets.load_dataset( + input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks" + ) + else: + dataset = datasets.load_dataset( + "AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache + ) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[start_idx:end_idx] + + self._dataset = dataset + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._dataset["pid"]) + + def __getitem__(self, idx): + # Already a PIL object. 
+ img = self._dataset['decoded_image'][idx] + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question_id = self._dataset["pid"][idx] + question = self._dataset["question"][idx] + question_type = self._dataset["question_type"][idx] # free_form or multi_choice + query = self._dataset["query"][idx] + choices = self._dataset["choices"][idx] + answer = self._dataset["answer"][idx] + + if question_type == 'multi_choice': + start_chr = 'A' + choices_str = '' + index2ans = {} + all_choices = [] + for choice in choices: + all_choices.append(start_chr) + index2ans[start_chr] = choice + choices_str += f"{start_chr}. {choice}\n" + start_chr = chr(ord(start_chr) + 1) + + question = question + '\n' + choices_str + question = question + "Answer with the option's letter from the given choices directly." + answer = chr(ord('A') + choices.index(answer)) + else: + question = query.replace("Hint: ", "") + index2ans = {} + all_choices = [] + + metadata = { + "question_type": question_type, + "index2ans": index2ans, + "all_choices": all_choices, + } + + return torch.stack(imgs), tile_count, question_id, question, answer, metadata + + +class AI2DDataset(torch.utils.data.Dataset): + """AI2D evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask, + vision_model_type, + ): + with open(gt_path, 'r') as f: + jsonl = list(f) + + gt = [json.loads(json_str) for json_str in jsonl] + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._gt = gt + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._no_mask = no_mask + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image']) + if self._no_mask: + img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES") + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = "" # Not used. 
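For reference, the MathVista multiple-choice branch above rewrites the raw choices into lettered options and maps the ground-truth answer to its letter. A small self-contained sketch of that transformation (the sample values are made up):

```python
# Illustrative sample, mirroring MathVistaDataset.__getitem__ above.
choices = ["12", "15", "18"]
answer = "15"

start_chr = "A"
choices_str = ""
index2ans = {}
all_choices = []
for choice in choices:
    all_choices.append(start_chr)
    index2ans[start_chr] = choice
    choices_str += f"{start_chr}. {choice}\n"
    start_chr = chr(ord(start_chr) + 1)

question = "Which value is shown?\n" + choices_str
question += "Answer with the option's letter from the given choices directly."
answer_letter = chr(ord("A") + choices.index(answer))

print(question)
print(answer_letter)  # "B", since "15" is the second choice.
```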
+ + return ( + torch.stack(imgs), + tile_count, + self._gt[idx]["question_id"], + self._gt[idx]["question"], + self._gt[idx]["answer"], + metadata, + ) + + +def get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + vision_model_type, +): + """Get an evaluation dataset.""" + if task == "TextVQA": + keys = { + "image_id": "image_id", + "sample_id": "question_id", + "question": "question", + "answer": "answers", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "VQAv2": + keys = { + "image_id": "image", + "sample_id": "question_id", + "question": "question", + "answer": "answer", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "ChartQA": + keys = {"image_id": "imgname", "question": "query", "answer": "label"} + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "captioning": + dataset = CaptioningDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == 'MMMU': + # Note: + # - prompt_style="single_image" uses only one image like in the MMMU repo example. + # - prompt_style="multi_image" uses multiple input images. 
+ # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499 + dataset = MMMUDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + prompt_style="single_image", + vision_model_type=vision_model_type, + ) + elif task == "VideoMME": + dataset = VideoMMEDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + vision_model_type, + ) + elif task == "OCRBench": + dataset = OCRBenchDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "MathVista": + dataset = MathVistaDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "AI2D": + dataset = AI2DDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask=False, + vision_model_type=vision_model_type, + ) + elif task == "SPDocVQA": + keys = {"sample_id": "questionId", "image_id": "image", "question": "question", "answer": "answers"} + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "InfoVQA": + keys = {"sample_id": "questionId", "image_id": "image_local_name", "question": "question", "answer": "answers"} + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + else: + raise NotImplementedError(f"unsupported task {task}") + + return dataset diff --git a/examples/multimodal/image_processing.py b/examples/multimodal/image_processing.py index ed9401c6798755df49805ef3b1c557538ddb59f6..3d3365d8e4eff569acccf73e80f31f172fa6de08 100644 --- a/examples/multimodal/image_processing.py +++ b/examples/multimodal/image_processing.py @@ -1,118 +1,143 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. 
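A hedged usage sketch of the `get_evaluation_dataset` factory from the hunk above; the module name, paths, and sizes below are placeholder assumptions, and every task returns the same `(imgs, tile_count, sample_id, question, answer, metadata)` tuple:

```python
# Assumed module name for the file in the hunk above (examples/multimodal evaluation datasets).
from evaluation_datasets import get_evaluation_dataset

# Paths and sizes are placeholders.
dataset = get_evaluation_dataset(
    task="TextVQA",
    input_image_path="/data/textvqa/train_images",
    gt_path="/data/textvqa/val_annotations.json",
    img_h=448,
    img_w=448,
    use_tiling=True,
    max_num_tiles=6,
    use_thumbnail=True,
    num_samples_per_partition=0,
    num_partitions=1,
    partition_id=0,
    num_frames=1,
    vision_model_type="siglip",
)

imgs, tile_count, sample_id, question, answer, metadata = dataset[0]
print(imgs.shape, int(tile_count), sample_id, question)
```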
-from torchvision import transforms as T -from torchvision.transforms import Compose -from torchvision.transforms.functional import InterpolationMode - - -IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] -IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] -SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] -SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] -CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] -CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] - - -pixel_statistics = { - "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), - "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), - "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), -} - - -def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, vision_model_type="clip"): - pixel_mean, pixel_std = pixel_statistics[vision_model_type] - - assert not augment, "Image augmentation not implemented." - transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type) - - if use_tiling: - assert img_h == img_w, "dynamic tiling expects equal tile height and width" - imgs = dynamic_preprocess(img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail) - imgs = [transform(img) for img in imgs] - else: - imgs = [transform(img)] - - return imgs - - -# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 -# Copyright (c) 2023 OpenGVLab. -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): - best_ratio_diff = float('inf') - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') - return best_ratio - - -# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702 -# Copyright (c) 2023 OpenGVLab. 
-def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): - orig_width, orig_height = image.size - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set( - (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if - i * j <= max_num and i * j >= min_num) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, target_ratios, orig_width, orig_height, image_size) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] - - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - return processed_images - - -# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 -# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 -def build_transform(input_size, pixel_mean, pixel_std, vision_model_type): - if vision_model_type in ("siglip", "internvit"): - transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std) - ]) - elif vision_model_type == "clip": - transform = Compose([ - T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.ToTensor(), - T.Normalize(mean=pixel_mean, std=pixel_std), - ]) - else: - raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}") - - return transform +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. 
+from torchvision import transforms as T +from torchvision.transforms import Compose +from torchvision.transforms.functional import InterpolationMode + + +IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] +IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] +SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] +SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] +CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] + + +pixel_statistics = { + "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), +} + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 +# Copyright (c) 2023 OpenGVLab. +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + """ + Find the best number of tiles based on the aspect ratio and the area covered by the tiles. + """ + best_factor = float('-inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + factor_based_on_area_n_ratio = ( + min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) * + min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)) + if factor_based_on_area_n_ratio > best_factor: + best_factor = factor_based_on_area_n_ratio + best_ratio = ratio + return best_ratio + + +def get_visual_transform( + img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, + vision_model_type="clip", find_closest_aspect_ratio_fn=find_closest_aspect_ratio): + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + assert not augment, "Image augmentation not implemented." + transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type) + + if use_tiling: + assert img_h == img_w, "dynamic tiling expects equal tile height and width" + imgs = dynamic_preprocess( + img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn) + imgs = [transform(img) for img in imgs] + else: + imgs = [transform(img)] + + return imgs + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702 +# Copyright (c) 2023 OpenGVLab. 
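To make the tile-grid selection above concrete, the sketch below builds the same candidate ratio set that `dynamic_preprocess` (defined next) constructs and asks both helpers which grid they would pick for a 16:9 input; the import assumes `examples/multimodal` is on `PYTHONPATH`:

```python
# Assumes examples/multimodal is on PYTHONPATH so this module is importable.
from image_processing import find_closest_aspect_ratio, find_closest_area_weighted_aspect_ratio

min_num, max_num, image_size = 1, 6, 448

# Same candidate grid set that dynamic_preprocess constructs.
target_ratios = sorted(
    {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    },
    key=lambda x: x[0] * x[1],
)

width, height = 1280, 720  # 16:9 input image.
aspect_ratio = width / height

# The plain variant only matches aspect ratio; the area-weighted variant also rewards
# grids whose total tile area covers more of the input.
print(find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size))
print(find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size))
```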
+def dynamic_preprocess( + image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 +# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 +def build_transform(input_size, pixel_mean, pixel_std, vision_model_type): + if vision_model_type in ("siglip", "internvit", "radio", "huggingface"): + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std) + ]) + elif vision_model_type == "clip": + transform = Compose([ + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=pixel_mean, std=pixel_std), + ]) + else: + raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}") + + return transform diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py index 2e07dc808da06936e89da6db9562a367a8e288fc..0f170fad87775393d36c03e3e1b09078c4f27da0 100644 --- a/examples/multimodal/layer_specs.py +++ b/examples/multimodal/layer_specs.py @@ -1,135 +1,139 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
-import torch - -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - -try: - from megatron.core.extensions.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TENorm, - TERowParallelLinear, - ) - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -try: - import apex - - from megatron.core.fusions.fused_layer_norm import FusedLayerNorm - from megatron.core.transformer.torch_norm import WrappedTorchNorm - - HAVE_APEX = True - LNImpl = FusedLayerNorm -except ImportError: - import warnings - - from megatron.core.transformer.torch_norm import WrappedTorchNorm - - warnings.warn(f'Apex is not installed. Falling back to Torch Norm') - LNImpl = WrappedTorchNorm - - -def get_layer_spec(is_vit, normalization) -> ModuleSpec: - attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal - if normalization == "LayerNorm": - norm = LNImpl - elif normalization == "RMSNorm": - if HAVE_TE: - norm = TENorm - else: - version = torch.__version__.split('.') - version_geq_2_4 = ( - int(TORCH_VERSION[0]) > 2 - or ( - int(TORCH_VERSION[0]) == 2 - and int(TORCH_VERSION[1]) >= 4 - ) - ) - assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm" - if HAVE_APEX: - warnings.warn(f'Apex does not support RMSNorm. Falling back to Torch Norm') - norm = WrappedTorchNorm - else: - raise RuntimeError("unknown normalization", normalization) - - mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=norm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": attn_mask_type}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - q_layernorm=IdentityOp, - k_layernorm=IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=norm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) - - -def get_layer_spec_te(is_vit=False) -> ModuleSpec: - attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal - - mlp = get_norm_mlp_module_spec_te() - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": attn_mask_type}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - q_layernorm=IdentityOp, - k_layernorm=IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) - - -def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: - # Dense MLP w/ or w/o TE modules. 
- return ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ), - ) - - -def get_norm_mlp_module_spec_te() -> ModuleSpec: - return ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn(f'Apex is not installed. Falling back to Torch Norm') + LNImpl = WrappedTorchNorm + + +def get_layer_spec(is_vit, normalization) -> ModuleSpec: + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + if normalization == "LayerNorm": + norm = LNImpl + elif normalization == "RMSNorm": + if HAVE_TE: + norm = TENorm + else: + version = torch.__version__.split('.') + version_geq_2_4 = ( + int(TORCH_VERSION[0]) > 2 + or ( + int(TORCH_VERSION[0]) == 2 + and int(TORCH_VERSION[1]) >= 4 + ) + ) + assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm" + if HAVE_APEX: + warnings.warn(f'Apex does not support RMSNorm. Falling back to Torch Norm') + norm = WrappedTorchNorm + else: + raise RuntimeError("unknown normalization", normalization) + + mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec: + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + # Padding mask is needed for e.g. Context Parallel. 
+ if padding: + assert not is_vit, "padding_causal mask not used with ViT" + attn_mask_type = AttnMaskType.padding_causal + + mlp = get_norm_mlp_module_spec_te() + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +def get_norm_mlp_module_spec_te() -> ModuleSpec: + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ) diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index a28a428325b8db9c7c1268080979889935dcc396..feca9a917188a379099db318b1ec838f7665ce31 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -1,216 +1,254 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import warnings -from copy import deepcopy - -import torch -from config import get_language_model_config, get_vision_model_config, get_vision_projection_config -from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te - -from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel -from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings -from megatron.training import get_args, get_tokenizer, print_rank_0 -from megatron.training.arguments import core_transformer_config_from_args - - -def model_provider( - pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True -) -> LLaVAModel: - """Builds the model. - - Args: - pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. - post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. - add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder - will live on only a subset of the pipeline stages (specifically, only the first stage). - add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder - will live on only a subset of the pipeline stages (specifically, every stage after the first one). - parallel_output (bool): Enable parallel model output. - - Returns: - model: A multimodal model. - """ - args = get_args() - assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently." 
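Stepping back to `layer_specs.py`: the new `padding` flag switches the attention mask type used by the language model spec. A short sketch of the effect (assumes Transformer Engine is installed so the TE-based spec can be constructed, and that `examples/multimodal` is importable):

```python
# Assumes examples/multimodal is importable and Transformer Engine is installed.
from megatron.core.transformer.enums import AttnMaskType

from layer_specs import get_layer_spec_te

# Default language-model spec: causal mask.
spec = get_layer_spec_te(is_vit=False)
assert spec.submodules.self_attention.params["attn_mask_type"] == AttnMaskType.causal

# model.py requests padding=True when context parallelism and sequence parallelism
# are both enabled, which switches to a padding_causal mask.
spec_cp = get_layer_spec_te(is_vit=False, padding=True)
assert spec_cp.submodules.self_attention.params["attn_mask_type"] == AttnMaskType.padding_causal
```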
- assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for encoder on it's own pipeline rank" - - use_te = args.use_te - - print_rank_0('building a multimodal model ...') - - num_image_embeddings = get_num_image_embeddings( - args.img_h, - args.img_w, - args.patch_dim, - args.vision_model_type, - args.disable_vision_class_token, - 1, - args.pixel_shuffle, - args.use_tile_tags, - ) - old_seq_length = args.seq_length - args.seq_length = args.encoder_seq_length = num_image_embeddings - if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: - warnings.warn( - f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" - ) - - max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings - - assert ( - args.decoder_seq_length is not None - ), "Please provide --decoder-seq-length to set the language model sequence length" - assert ( - args.decoder_seq_length > max_num_image_embeddings - ), "Language model sequence length must be greater than the maximum number of image embeddings" - if args.decoder_seq_length > args.max_position_embeddings: - args.max_position_embeddings = args.decoder_seq_length - warnings.warn( - f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length" - ) - - base_config = core_transformer_config_from_args(get_args()) - base_config.language_model_type = args.language_model_type - base_config.vision_model_type = args.vision_model_type - base_config.calculate_per_token_loss = True - - language_config = deepcopy(base_config) - language_config = get_language_model_config(language_config) - - if use_te: - language_transformer_layer_spec = get_layer_spec_te( - is_vit=False - ) # TENorm detects LayerNorm/RMS automatically. - else: - language_transformer_layer_spec = get_layer_spec( - is_vit=False, normalization=language_config.normalization - ) - - vision_config = deepcopy(base_config) - vision_config = get_vision_model_config( - vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling - ) - - vision_model_type = args.vision_model_type - if vision_model_type in ["clip", "siglip"]: - if use_te: - vision_transformer_layer_spec = get_layer_spec_te( - is_vit=True - ) # TENorm detects LayerNorm/RMS automatically. - else: - vision_transformer_layer_spec = get_layer_spec( - is_vit=True, normalization=vision_config.normalization - ) - elif vision_model_type == "internvit": - from nvlm.internvit import get_internvit_layer_spec - vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te) - else: - raise RuntimeError("unsupported vision model type", vision_model_type) - - vision_projection_config = deepcopy(base_config) - vision_projection_config = get_vision_projection_config( - vision_projection_config, language_config.hidden_size - ) - - # --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model. - if args.encoder_pipeline_model_parallel_size > 0: - assert ( - args.encoder_pipeline_model_parallel_size == 1 - ), "vision model and projection can only live on 1 pipeline stage." 
- - if args.encoder_tensor_model_parallel_size > 0: - vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size - vision_projection_config.tensor_model_parallel_size = ( - args.encoder_tensor_model_parallel_size - ) - - # Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size. - # 0 is not a valid for the config value, hence max(1, ). - vision_config.pipeline_model_parallel_size = max(1, args.encoder_pipeline_model_parallel_size) - vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size - - # Make sure the vision model does not inherit first and last pipeline num layers from the language model. - vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None - - if vision_projection_config.normalization: - vision_projection_layer_spec = get_norm_mlp_module_spec_te().submodules - else: - vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules - - # Toggle --recompute* for the vision and language model separately. - if args.recompute_vision: - if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None: - vision_config.recompute_num_layers = vision_config.num_layers - else: - vision_config.recompute_granularity = None - vision_config.recompute_method = None - vision_config.recompute_num_layers = None - - vision_projection_config.recompute_granularity = None - vision_projection_config.recompute_method = None - vision_projection_config.recompute_num_layers = None - - - tokenizer = get_tokenizer() - image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - - tile_tags = _get_tile_tags(args, tokenizer) - - model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_transformer_layer_spec, - language_vocab_size=args.padded_vocab_size, - language_max_sequence_length=args.decoder_seq_length, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_transformer_layer_spec, - drop_vision_class_token=args.disable_vision_class_token, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_layer_spec, - vision_projection_type="mlp", - allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint, - parallel_output=parallel_output, - language_position_embedding_type=args.position_embedding_type, - language_rotary_percent=args.rotary_percent, - pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - language_rotary_base=args.rotary_base, - language_rope_scaling=args.use_rope_scaling, - image_token_index=image_token_index, - pixel_shuffle=args.pixel_shuffle, - tile_tags=tile_tags, - ) - - model.freeze( - freeze_language_model=args.freeze_LM, - freeze_vision_model=args.freeze_ViT, - freeze_vision_projection=False, - ) - - return model - - -def _get_tile_tags(args, tokenizer): - """Tile tags are used in NVLM to surround image tiles with text tags.""" - if not args.use_tile_tags: - return None - - # We expect the tokenized length of the tags is same. 
- thumbnail_tag_text = "" - if args.tokenizer_prompt_format == "nvlm-yi-34b": - thumbnail_tag_text = "" - - assert args.max_num_tiles <= 6, "Up to 6 tile tags used" - tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] - - start_idx = 0 - if tokenizer._prompt_config.has_bos: - start_idx = 1 - - # Convert to tokens [num_tiles, tile_seq_len]. - tile_tags = [tokenizer.tokenize(t)[start_idx:] for t in tile_tags_text] - - return tile_tags +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import warnings +from copy import deepcopy + +import torch +from config import get_language_model_config, get_vision_model_config, get_vision_projection_config +from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te + +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.training import get_args, get_tokenizer, print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args + + +def model_provider( + pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True +) -> LLaVAModel: + """Builds the model. + + Args: + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + parallel_output (bool): Enable parallel model output. + + Returns: + model: A multimodal model. 
+ """ + args = get_args() + assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for encoder on it's own pipeline rank" + + use_te = args.use_te + + print_rank_0('building a multimodal model ...') + + num_image_embeddings = get_num_image_embeddings( + args.img_h, + args.img_w, + args.patch_dim, + args.vision_model_type, + args.disable_vision_class_token, + 1, + args.pixel_shuffle, + args.use_tile_tags, + ) + old_seq_length = args.seq_length + args.seq_length = args.encoder_seq_length = num_image_embeddings + if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: + warnings.warn( + f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" + ) + + max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings + + assert ( + args.decoder_seq_length is not None + ), "Please provide --decoder-seq-length to set the language model sequence length" + assert ( + args.decoder_seq_length > max_num_image_embeddings + ), "Language model sequence length must be greater than the maximum number of image embeddings" + if args.decoder_seq_length > args.max_position_embeddings: + args.max_position_embeddings = args.decoder_seq_length + warnings.warn( + f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length" + ) + + base_config = core_transformer_config_from_args(get_args()) + base_config.language_model_type = args.language_model_type + base_config.vision_model_type = args.vision_model_type + base_config.calculate_per_token_loss = True + + language_config = deepcopy(base_config) + language_config = get_language_model_config(language_config) + + if use_te: + # Padding mask needed for SP/CP. + padding = args.context_parallel_size > 1 and args.sequence_parallel + language_transformer_layer_spec = get_layer_spec_te( + is_vit=False, padding=padding + ) # TENorm detects LayerNorm/RMS automatically. 
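The sequence-length bookkeeping above depends on the per-tile image embedding count. A back-of-the-envelope check with assumed numbers (448x448 tiles, patch size 14, no class token or pixel shuffle; these are illustrative values, not settings fixed by this change):

```python
# Illustrative values only; the real numbers come from get_num_image_embeddings() and the args.
img_h = img_w = 448
patch_dim = 14
max_num_tiles = 6
use_thumbnail = True

# Per-tile embeddings for a plain ViT without class token or pixel shuffle.
num_image_embeddings = (img_h // patch_dim) * (img_w // patch_dim)  # 32 * 32 = 1024

# Worst case per sample: every tile plus the optional thumbnail.
max_num_image_embeddings = (max_num_tiles + int(use_thumbnail)) * num_image_embeddings  # 7168

# --decoder-seq-length must leave room for the text tokens on top of the image embeddings.
decoder_seq_length = 8192
assert decoder_seq_length > max_num_image_embeddings
print(decoder_seq_length - max_num_image_embeddings, "tokens left for text")
```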
+ else: + language_transformer_layer_spec = get_layer_spec( + is_vit=False, normalization=language_config.normalization + ) + + vision_model_type = args.vision_model_type + vision_config = deepcopy(base_config) + vision_config = get_vision_model_config( + vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling + ) + if vision_model_type.startswith("huggingface"): + assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1" + assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0" + assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" + assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert args.vision_huggingface_model_name_or_path is not None, "Providing --vision-huggingface-model-name-or-path is necessary when using huggingface vision model" + + vision_config.huggingface_model_name_or_path = args.vision_huggingface_model_name_or_path + + from transformers import AutoConfig + huggingface_config = AutoConfig.from_pretrained(vision_config.huggingface_model_name_or_path) + vision_config.hidden_size = huggingface_config.hidden_size + + vision_model_type = args.vision_model_type + if vision_model_type in ["clip", "siglip", "radio"]: + if use_te: + vision_transformer_layer_spec = get_layer_spec_te( + is_vit=True + ) # TENorm detects LayerNorm/RMS automatically. + else: + vision_transformer_layer_spec = get_layer_spec( + is_vit=True, normalization=vision_config.normalization + ) + elif vision_model_type == "internvit": + from nvlm.internvit import get_internvit_layer_spec + vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te) + elif vision_model_type.startswith("huggingface"): + vision_transformer_layer_spec = None + else: + raise RuntimeError("unsupported vision model type", vision_model_type) + + vision_projection_config = deepcopy(base_config) + + if base_config.language_model_type.startswith("huggingface"): + assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1" + assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1" + assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" + assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert args.language_huggingface_model_name_or_path is not None, "Providing --language-huggingface-model-name-or-path is necessary when using huggingface language model" + + language_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path + # Pass to vision projection config so can choose the correct ffn hidden size + vision_projection_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path + + vision_projection_config = get_vision_projection_config( + vision_projection_config, language_config.hidden_size + ) + + # --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model. + if args.encoder_pipeline_model_parallel_size > 0: + assert ( + args.encoder_pipeline_model_parallel_size == 1 + ), "vision model and projection can only live on 1 pipeline stage." 
+ + if args.encoder_tensor_model_parallel_size > 0: + vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size + vision_projection_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + + # Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size. + # 0 is not a valid for the config value, hence max(1, ). + vision_config.pipeline_model_parallel_size = max(1, args.encoder_pipeline_model_parallel_size) + vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size + + # Make sure the vision model does not inherit first and last pipeline num layers from the language model. + vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None + + if vision_projection_config.normalization: + vision_projection_layer_spec = get_norm_mlp_module_spec_te().submodules + else: + vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules + + # Toggle --recompute* for the vision and language model separately. + if args.recompute_vision: + if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None: + vision_config.recompute_num_layers = vision_config.num_layers + else: + vision_config.recompute_granularity = None + vision_config.recompute_method = None + vision_config.recompute_num_layers = None + + vision_projection_config.recompute_granularity = None + vision_projection_config.recompute_method = None + vision_projection_config.recompute_num_layers = None + + # TODO: Vision model and projection do not use SP/CP yet. + vision_config.sequence_parallel = False + vision_config.context_parallel_size = 1 + vision_config.tp_comm_overlap = False + + vision_projection_config.sequence_parallel = False + vision_projection_config.context_parallel_size = 1 + vision_projection_config.tp_comm_overlap = False + + tokenizer = get_tokenizer() + image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg." 
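The assert above fires when `IMAGE_TOKEN` was never registered with the tokenizer. As a loose analogy (using a Hugging Face tokenizer rather than the Megatron tokenizer used here), this is the effect that registering the token via `--special-tokens` is expected to have:

```python
from transformers import AutoTokenizer

# Hypothetical base tokenizer; only the special-token mechanics matter here.
tok = AutoTokenizer.from_pretrained("gpt2")

image_token = "<image>"  # assumed to correspond to IMAGE_TOKEN; purely illustrative.
print(tok.convert_tokens_to_ids(image_token))  # unk id (or None) before registration.

tok.add_special_tokens({"additional_special_tokens": [image_token]})
print(tok.convert_tokens_to_ids(image_token))  # a real id once the token is registered.
```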
+ + tile_tags = _get_tile_tags(args, tokenizer) + + model = LLaVAModel( + language_transformer_config=language_config, + language_transformer_layer_spec=language_transformer_layer_spec, + language_vocab_size=args.padded_vocab_size, + language_max_sequence_length=args.decoder_seq_length, + vision_transformer_config=vision_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + drop_vision_class_token=args.disable_vision_class_token, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint, + parallel_output=parallel_output, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + language_position_embedding_type=args.position_embedding_type, + language_rotary_percent=args.rotary_percent, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + language_rotary_base=args.rotary_base, + language_rope_scaling=args.use_rope_scaling, + image_token_index=image_token_index, + pixel_shuffle=args.pixel_shuffle, + tile_tags=tile_tags, + ) + + model.freeze( + freeze_language_model=args.freeze_LM, + freeze_vision_model=args.freeze_ViT, + freeze_vision_projection=False, + ) + + return model + + +def _get_tile_tags(args, tokenizer): + """Tile tags are used in NVLM to surround image tiles with text tags.""" + if not args.use_tile_tags: + return None + + # We expect the tokenized length of the tags is same. + thumbnail_tag_text = "" + if args.tokenizer_prompt_format == "nvlm-yi-34b": + thumbnail_tag_text = "" + + assert args.max_num_tiles <= 6, "Up to 6 tile tags used" + tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] + + start_idx = 0 + if tokenizer._prompt_config.has_bos: + start_idx = 1 + + # Convert to tokens [num_tiles, tile_seq_len]. + tile_tags = [tokenizer.tokenize(t)[start_idx:] for t in tile_tags_text] + + return tile_tags diff --git a/examples/multimodal/model_converter/internvit_converter.py b/examples/multimodal/model_converter/internvit_converter.py index 48404c2084cc84bead036b4ae82ce1d440dab101..544e2600e4a698ade6d20d941acd2ee6f40cde62 100644 --- a/examples/multimodal/model_converter/internvit_converter.py +++ b/examples/multimodal/model_converter/internvit_converter.py @@ -1,162 +1,162 @@ -import argparse -import os - -import torch -from transformers import AutoModel - - -def convert(model_name, output_path, tensor_parallel_size, use_te): - """Convert InternViT HF checkpoint to mcore.""" - hf_model = AutoModel.from_pretrained( - model_name, - trust_remote_code=True - ) - - hf_state_dict = hf_model.state_dict() - new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] - - hidden_size = 3200 - num_heads = 25 - dim = 128 - - order = torch.ones(3 * hidden_size).long() - - for j in range(num_heads): - for i in range(dim): - order[i + dim*3*j] = j*dim+i - order[dim + i + dim*3*j] = j*dim+i+num_heads*dim - order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2 - - for name, tensor in hf_state_dict.items(): - # Map parameter names to ones used in megatron. - new_name = "" - new_tensor = tensor - - # This is used for chunking some tensors to target tensor parallel size. 
- chunk_dim = None - - if "embeddings.class_embedding" in name: - new_name = "class_token" - elif "embeddings.patch_embedding.weight" in name: - new_name = "conv1.weight" - elif "embeddings.patch_embedding.bias" in name: - new_name = "conv1.bias" - elif "embeddings.position_embedding" in name: - new_name = "position_embeddings.weight" - new_tensor = new_tensor.squeeze(0) - elif "encoder.layers" in name: - layer_idx = name.split(".")[2] - - base = f"decoder.layers.{layer_idx}" - - head_dim = 128 - - if tensor_parallel_size == 1: - num_padded_heads = 25 - elif tensor_parallel_size == 8: - # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism. - # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model. - num_padded_heads = 32 - else: - raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size) - - if "ls1" in name: - new_name = f"{base}.ls1" - elif "ls2" in name: - new_name = f"{base}.ls2" - elif "attn.qkv.weight" in name: - new_name = f"{base}.self_attention.linear_qkv.weight" - num_tensors = 3 - padded_dim = head_dim * num_padded_heads * num_tensors - padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device) - padded_tensor[:new_tensor.shape[0], :] = new_tensor[order] - new_tensor = padded_tensor - chunk_dim = 0 - elif "attn.q_norm.weight" in name: - new_name = f"{base}.self_attention.q_layernorm.weight" - num_tensors = 1 - padded_dim = head_dim * num_padded_heads * num_tensors - padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) - padded_tensor[:new_tensor.shape[0]] = new_tensor - new_tensor = padded_tensor - chunk_dim = 0 - elif "attn.k_norm.weight" in name: - new_name = f"{base}.self_attention.k_layernorm.weight" - num_tensors = 1 - padded_dim = head_dim * num_padded_heads * num_tensors - padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) - padded_tensor[:new_tensor.shape[0]] = new_tensor - new_tensor = padded_tensor - chunk_dim = 0 - elif "attn.proj.weight" in name: - new_name = f"{base}.self_attention.linear_proj.weight" - num_tensors = 1 - padded_dim = head_dim * num_padded_heads * num_tensors - padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device) - padded_tensor[:, :new_tensor.shape[-1]] = new_tensor - new_tensor = padded_tensor - chunk_dim = 1 - elif "attn.proj.bias" in name: - new_name = f"{base}.self_attention.linear_proj.bias" - elif "mlp.fc1.weight" in name: - new_name = f"{base}.mlp.linear_fc1.weight" - chunk_dim = 0 - elif "mlp.fc1.bias" in name: - new_name = f"{base}.mlp.linear_fc1.bias" - chunk_dim = 0 - elif "mlp.fc2.weight" in name: - new_name = f"{base}.mlp.linear_fc2.weight" - chunk_dim = 1 - elif "mlp.fc2.bias" in name: - new_name = f"{base}.mlp.linear_fc2.bias" - elif "norm1" in name: - new_name = f"{base}.input_layernorm.weight" - elif "norm2" in name: - new_name = f"{base}.pre_mlp_layernorm.weight" - else: - raise RuntimeError("unexpected transformer layer name", name) - else: - raise RuntimeError("unexpected layer name", name) - - assert new_name != "", f"unexpected layer name {name}" - - # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. 
- extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") - is_extra_state_layer = any([l in new_name for l in extra_state_layers]) - if use_te and is_extra_state_layer: - layer = new_name.split(".")[-2] - if layer in extra_state_layers: - extra_state_name = ( - new_name[: new_name.rfind(".") + 1] + "_extra_state" - ) # Replace the weight name. - for i in range(tensor_parallel_size): - new_state_dicts[i]["model"][extra_state_name] = None - - if chunk_dim is None: - new_tensors = [new_tensor for _ in range(tensor_parallel_size)] - else: - new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) - - for i in range(tensor_parallel_size): - new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() - - for i in range(tensor_parallel_size): - output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}") - os.makedirs(output_dir_tp, exist_ok=True) - output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") - torch.save(new_state_dicts[i], output_path_tp) - print("saved file", output_path_tp) - - print("done") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter") - parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace") - parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.") - parser.add_argument("--use-te", action="store_true", default=True) - parser.add_argument("--tensor-parallel-size", type=int, required=True) - - args = parser.parse_args() - - convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te) +import argparse +import os + +import torch +from transformers import AutoModel + + +def convert(model_name, output_path, tensor_parallel_size, use_te): + """Convert InternViT HF checkpoint to mcore.""" + hf_model = AutoModel.from_pretrained( + model_name, + trust_remote_code=True + ) + + hf_state_dict = hf_model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + hidden_size = 3200 + num_heads = 25 + dim = 128 + + order = torch.ones(3 * hidden_size).long() + + for j in range(num_heads): + for i in range(dim): + order[i + dim*3*j] = j*dim+i + order[dim + i + dim*3*j] = j*dim+i+num_heads*dim + order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2 + + for name, tensor in hf_state_dict.items(): + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + + # This is used for chunking some tensors to target tensor parallel size. + chunk_dim = None + + if "embeddings.class_embedding" in name: + new_name = "class_token" + elif "embeddings.patch_embedding.weight" in name: + new_name = "conv1.weight" + elif "embeddings.patch_embedding.bias" in name: + new_name = "conv1.bias" + elif "embeddings.position_embedding" in name: + new_name = "position_embeddings.weight" + new_tensor = new_tensor.squeeze(0) + elif "encoder.layers" in name: + layer_idx = name.split(".")[2] + + base = f"decoder.layers.{layer_idx}" + + head_dim = 128 + + if tensor_parallel_size == 1: + num_padded_heads = 25 + elif tensor_parallel_size == 8: + # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism. + # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model. 
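The padding described in the comment above is worth working through once with concrete numbers. The sketch below is an editorial aside under stated assumptions (InternViT-6B: 25 real heads, head_dim 128, hidden size 3200, TP=8; heads padded to 32 so the interleaved per-head q/k/v rows split evenly); the variable names are illustrative, not taken from the converter.

# Hedged sketch of the head-padding bookkeeping used by this converter.
import torch

head_dim, real_heads, hidden, tp = 128, 25, 3200, 8
padded_heads = 32                                   # smallest multiple of tp that fits 25 heads
padded_rows = head_dim * padded_heads * 3           # interleaved per-head [q, k, v] rows: 12288
qkv = torch.zeros(padded_rows, hidden)              # stands in for the padded linear_qkv.weight
rows_per_rank = torch.chunk(qkv, tp, dim=0)[0].shape[0]   # 1536 rows = 4 heads * 3 * 128
heads_per_rank = padded_heads // tp                 # 4
real_per_rank = [min(max(real_heads - r * heads_per_rank, 0), heads_per_rank) for r in range(tp)]
print(rows_per_rank, real_per_rank)                 # 1536 [4, 4, 4, 4, 4, 4, 1, 0]

Ranks 0-5 end up with only real heads, rank 6 with one real head plus three all-zero dummies, and rank 7 with padding only, which is the split that the InternViT-specific RMSNorm and attention code in nvlm/internvit.py has to compensate for.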
+ num_padded_heads = 32 + else: + raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size) + + if "ls1" in name: + new_name = f"{base}.ls1" + elif "ls2" in name: + new_name = f"{base}.ls2" + elif "attn.qkv.weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + num_tensors = 3 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0], :] = new_tensor[order] + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.q_norm.weight" in name: + new_name = f"{base}.self_attention.q_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.k_norm.weight" in name: + new_name = f"{base}.self_attention.k_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:, :new_tensor.shape[-1]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 1 + elif "attn.proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "norm1" in name: + new_name = f"{base}.input_layernorm.weight" + elif "norm2" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + else: + raise RuntimeError("unexpected transformer layer name", name) + else: + raise RuntimeError("unexpected layer name", name) + + assert new_name != "", f"unexpected layer name {name}" + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. 
+ for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][extra_state_name] = None + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}") + os.makedirs(output_dir_tp, exist_ok=True) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + print("saved file", output_path_tp) + + print("done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter") + parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.") + parser.add_argument("--use-te", action="store_true", default=True) + parser.add_argument("--tensor-parallel-size", type=int, required=True) + + args = parser.parse_args() + + convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te) diff --git a/examples/multimodal/model_converter/radio_converter.py b/examples/multimodal/model_converter/radio_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..e681e3db6be5c12b5c577757e74be01fb7946279 --- /dev/null +++ b/examples/multimodal/model_converter/radio_converter.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os + +import torch + +def convert(output_path, tensor_parallel_size, use_te, version): + device = "cuda" + + model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=version, progress=True) + + state_dict = model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + # Indices from mapping pytorch multihead attention to megatron. + kv_channels = 80 + hidden_dim = 1280 + num_heads = 16 + indices = [] + for i in range(num_heads): + lb = i * kv_channels + ub = (i + 1) * kv_channels + indices.append(torch.arange(lb, ub, dtype=torch.int)) + indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int)) + indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int)) + + indices = torch.cat(indices) + + for name, tensor in state_dict.items(): + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + if new_tensor.dtype == torch.float16: + new_tensor = new_tensor.to(torch.float32) + + # This is used for chunking some tensors to target tensor parallel size. 
+ chunk_dim = None + + if "summary_idxs" in name: + continue + elif "patch_generator" in name: + if "embedder" in name: + new_name = "embedder.weight" + chunk_dim = 0 + elif "cls_token" in name: + new_name = "class_token" + elif "pos_embed" in name: + new_name = "position_embeddings" + elif "input_conditioner" in name: + continue + elif "blocks" in name: + layer_idx = name.split(".")[2] + base = f"decoder.layers.{layer_idx}" + + if "attn.qkv.weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.qkv.bias" in name: + new_name = f"{base}.self_attention.linear_qkv.bias" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + chunk_dim = 1 + elif "attn.proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "norm1.weight" in name: + new_name = f"{base}.input_layernorm.weight" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight" + elif "norm1.bias" in name: + new_name = f"{base}.input_layernorm.bias" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "norm2.weight" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_weight" + elif "norm2.bias" in name: + new_name = f"{base}.pre_mlp_layernorm.bias" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_bias" + + assert new_name != "", f"unexpected layer name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. + new_state_dicts[i]["model"][extra_state_name] = None + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") + os.makedirs(output_dir_tp) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + with open(os.path.join(output_path, "latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert RADIO weights to megatron format. 
+ + +Example usage: +python radio_converter.py --output /some/output/folder --tensor-parallel-size 4 +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--output", type=str, required=True, help="output directory for megatron state dict file(s)" + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + parser.add_argument("--version", type=str, default="radio_v2.5-h", help="Version of radio to load for conversion") + + args = parser.parse_args() + + convert(args.output, args.tensor_parallel_size, args.use_te, args.version) + + print("done.") diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py index eb56118e71613ea7fae6f81ff44f2969f26b4533..22fadc9c5e7dc480f57a6c3ccf8877f755576039 100644 --- a/examples/multimodal/multimodal_args.py +++ b/examples/multimodal/multimodal_args.py @@ -1,79 +1,89 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN - - -def add_multimodal_extra_args(parser): - """Extra arguments.""" - group = parser.add_argument_group(title='multimodal arguments') - group.add_argument('--dataset-config', type=str, default=None) - group.add_argument("--prompt-path", type=str, default=None) - group.add_argument('--freeze-LM', action='store_true', default=False) - group.add_argument('--freeze-ViT', action='store_true', default=False) - group.add_argument('--language-model-type', type=str, required=True) - group.add_argument('--vision-model-type', type=str, default="clip") - group.add_argument("--disable-vision-class-token", action="store_true", default=False) - group.add_argument( - "--allow-missing-vision-projection-checkpoint", action="store_true", default=False - ) - group.add_argument("--use-te", action="store_true", default=False) - group.add_argument( - "--dataloader-save", type=str, default=None, help="Energon dataloader state save path" - ) - group.add_argument( - "--use-tiling", action="store_true", default=False, help="Use input image tiling" - ) - group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles") - group.add_argument( - "--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile" - ) - group.add_argument( - "--dataloader-seq-length", - type=int, - help="Make dataloader to produce sequences of specific length.", - ) - group.add_argument( - "--num-frames", - type=int, - default=1, - help="Number of frames to regularly sample from the video as input to the model.", - ) - group.add_argument( - "--online-evaluation-config", type=str, help="Config file for online evaluation." - ) - group.add_argument( - "--special-tokens", - nargs="*", - default=[IMAGE_TOKEN], - help="Special tokens used in the multimodal model", - ) - group.add_argument( - "--tokenizer-prompt-format", - type=str, - choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"], - required=True, - help="Prompt format to use with the tokenizer.", - ) - group.add_argument("--pixel-shuffle", action="store_true", default=False) - group.add_argument( - "--image-tag-type", - type=str, - choices=["nvlm", "internvl", ""], - default="", # Default: Image tag not used. 
- help="Surround image tokens with tags.", - ) - group.add_argument("--use-tile-tags", action="store_true", default=False, help="Use tile tags") - group.add_argument( - "--packing-buffer-size", - type=int, - default=None, # Packing is disabled by default. - help="Enable sample packing by setting the buffer size to > 0", - ) - group.add_argument( - "--packing-seq-length", type=int, default=0, help="Packing sequence length. Must be > 0 if using packing." - ) - group.add_argument( - "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model" - ) - - - return parser +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN + + +def add_multimodal_extra_args(parser): + """Extra arguments.""" + group = parser.add_argument_group(title='multimodal arguments') + group.add_argument('--dataset-config', type=str, default=None) + group.add_argument("--prompt-path", type=str, default=None) + group.add_argument('--freeze-LM', action='store_true', default=False) + group.add_argument('--freeze-ViT', action='store_true', default=False) + group.add_argument('--language-model-type', type=str, required=True) + group.add_argument('--language-huggingface-model-name-or-path', type=str) + group.add_argument('--vision-model-type', type=str, default="clip") + group.add_argument('--vision-huggingface-model-name-or-path', type=str) + group.add_argument("--disable-vision-class-token", action="store_true", default=False) + group.add_argument( + "--allow-missing-vision-projection-checkpoint", action="store_true", default=False + ) + group.add_argument("--use-te", action="store_true", default=False) + group.add_argument( + "--dataloader-save", type=str, default=None, help="Energon dataloader state save path" + ) + group.add_argument( + "--use-tiling", action="store_true", default=False, help="Use input image tiling" + ) + group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles") + group.add_argument( + "--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile" + ) + group.add_argument( + "--dataloader-seq-length", + type=int, + help="Make dataloader to produce sequences of specific length.", + ) + group.add_argument( + "--num-frames", + type=int, + default=1, + help="Number of frames to regularly sample from the video as input to the model.", + ) + group.add_argument( + "--online-evaluation-config", type=str, help="Config file for online evaluation." + ) + group.add_argument( + "--special-tokens", + nargs="*", + default=[IMAGE_TOKEN], + help="Special tokens used in the multimodal model", + ) + group.add_argument( + "--tokenizer-prompt-format", + type=str, + choices=["mistral", "llama3", "llama3p1", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"], + required=True, + help="Prompt format to use with the tokenizer.", + ) + group.add_argument("--pixel-shuffle", action="store_true", default=False) + group.add_argument( + "--image-tag-type", + type=str, + choices=["nvlm", "internvl", ""], + default="", # Default: Image tag not used. + help="Surround image tokens with tags.", + ) + group.add_argument("--use-tile-tags", action="store_true", default=False, help="Use tile tags") + group.add_argument( + "--packing-buffer-size", + type=int, + default=None, # Packing is disabled by default. 
+ help="Enable sample packing by setting the buffer size to > 0", + ) + group.add_argument( + "--packing-seq-length", type=int, default=0, help="Packing sequence length. Must be > 0 if using packing." + ) + group.add_argument( + "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model" + ) + group.add_argument( + "--use-loss-scaling", action="store_true", default=False, help="Scale loss based on conversation turn length (in tokens)." + ) + group.add_argument( + "--use-area-weighted-aspect-ratio", action="store_true", default=False, + help=( + "When --use-tiling is True, find the aspect ratio to use based on the original ", + "image aspect ratio and the area covered by the tiles.") + ) + + return parser diff --git a/examples/multimodal/nvlm/internvit.py b/examples/multimodal/nvlm/internvit.py index cd116ffb76c13634fdfef12994df497122340653..ac560ed70ba6f37998c014a06479eda09df03b46 100644 --- a/examples/multimodal/nvlm/internvit.py +++ b/examples/multimodal/nvlm/internvit.py @@ -1,273 +1,279 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -"""" -NOTE: NVLM uses InternViT with tensor parallel (TP) size = 8. -Since InternViT has 25 attention heads and Megatron currently requires the number of attention heads -to be divisible by the TP size, we add 7 dummy zero attention heads to have 32 attention heads. - -This workaround requires some changes to how we compute RMSNorm, Attention etc. - -Additionally, InternViT introduces some unique features like Layer Scaling. - -Those code changes are gathered here. -""" -from functools import partial -from typing import Dict - -import torch - -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.extensions.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TERowParallelLinear, -) -from megatron.core.parallel_state import ( - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint - - -class InternViTRMSNorm(MegatronModule): - - def __init__( - self, - config, - hidden_size: int, - eps: float = 1e-6, - sequence_parallel: bool = False, - compute_var: bool = False, - ): - """Custom RMSNorm for InternViT. - - Args: - config (TransformerConfig): Config. - hidden_size (int): Input hidden size. - eps (float): epsilon to use for the norm, default to 1e-6 - sequence_parallel (bool): Set to true if sequence parallelism is being used, - this marks the weights as needing to be allreduced. - compute_var (bool): Indicator to compute statistic manually. 
- """ - super().__init__(config=config) - self.config = config - self.eps = eps - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - self._compute_var = compute_var - - assert not sequence_parallel, "Sequence parallelism is not supported with InternViT." - - setattr(self.weight, 'sequence_parallel', sequence_parallel) - - def _norm(self, x, var): - if var is None: - var = x.pow(2).mean(-1, keepdim=True) - - return x * torch.rsqrt(var + self.eps) - - def forward(self, x): - """Run RMSNorm with an option to compute custom statistic.""" - var = None - if self._compute_var: - unpadded_hidden_size = self.config.hidden_size # 3200 - max_dim = x.shape[-1] # 128 - - x = x.reshape(x.size(0), x.size(1), -1) - var = self._gather_var(x.float().pow(2), max_dim) / unpadded_hidden_size - - output = self._norm(x.float(), var).type_as(x) - output = output * self.weight - - if self._compute_var: - output = output.reshape(output.size(0), output.size(1), -1, max_dim) - - return output - - def _gather_var(self, input_, max_dim, valid_ranks=6): - """Compute statistic across the non-dummy heads.""" - world_size = get_tensor_model_parallel_world_size() - assert world_size == 8, "tested only with TP=8" - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - if rank < valid_ranks: # Ranks 0-5 have 24 non-dummy attention heads. - var = input_.sum(-1, keepdim=True) - elif rank == valid_ranks: # Rank 6 has 1 non-dummy attention head. - var = input_[..., :max_dim].sum(-1, keepdim=True) - else: - var = input_.sum(-1, keepdim=True) * 0.0 # Zero-out the dummy heads. - - tensor_list = [torch.empty_like(var) for _ in range(world_size)] - tensor_list[rank] = var - torch.distributed.all_gather(tensor_list, var, group=get_tensor_model_parallel_group()) - - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output.sum(-1, keepdim=True) - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata={}): - - # in InternVitSelfAttention the q_layernorm and k_layernorm weights - # are tensor-parallel so must be converted to sharded tensors - if 'q_layernorm' in prefix or 'k_layernorm' in prefix: - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'weight': 0}, sharded_offsets - ) - else: - return super().sharded_state_dict(prefix, sharded_offsets, metadata) - - -def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: - # Dense MLP w/ or w/o TE modules. - return ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ), - ) - - -# Handle InternViT's layer scaling. 
-def _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training): - x, bias = x_with_bias # unpack - residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) - if bias is not None: - x = x + bias - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out * ls - return out - else: - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out * ls - return out - - -def bias_dropout_add_unfused_internvit(ls, training): - """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" - - def _bias_dropout_add(x_with_bias, residual, prob): - return _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training) - - return _bias_dropout_add - - -def get_bias_dropout_add_internvit(ls, training, fused): - """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" - assert not fused, "Fused bias-dropout-add not implemented for InternViT." - return bias_dropout_add_unfused_internvit(ls, training) - - -# Add InternViT specialties to our default TransformerLayer. -class InternViTTransformerLayer(TransformerLayer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.ls1 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) - self.ls2 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) - - self.self_attn_bda = partial(self.self_attn_bda, self.ls1) - self.mlp_bda = partial(self.mlp_bda, self.ls2) - - -# Override a few things that are special in InternViT and not supported by the SelfAttention class. -class InternViTSelfAttention(SelfAttention): - def __init__( - self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs - ): - super().__init__(config=config, submodules=submodules, *args, **kwargs) - - # Need to override linear_qkv, q_layernorm and k_layernorm. - qkv_bias = False - - self.linear_qkv = build_module( - submodules.linear_qkv, - self.config.hidden_size, - self.query_projection_size + 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=qkv_bias, - skip_bias_add=False, - is_expert=False, - tp_comm_buffer_name='qkv', - ) - - qk_layernorm_hidden_size = ( - self.hidden_size_per_attention_head * self.num_attention_heads_per_partition - ) # 512 for internvit - - self.q_layernorm = build_module( - submodules.q_layernorm, - hidden_size=qk_layernorm_hidden_size, - config=self.config, - eps=self.config.layernorm_epsilon, - compute_var=True, - ) - - self.k_layernorm = build_module( - submodules.k_layernorm, - hidden_size=qk_layernorm_hidden_size, - config=self.config, - eps=self.config.layernorm_epsilon, - compute_var=True, - ) - - -class InternViTTEDotProductAttention(TEDotProductAttention): - """Adjusted Attention for InternViT""" - - def forward(self, *args, **kwargs): - """Regular TEDotProductAttention + zero-out dummy attention heads.""" - out = super().forward(*args, **kwargs) - - # This makes sure the dummy attention heads are zeroed out. 
- mask = torch.ones_like(out, dtype=out.dtype, device=out.device) - rank = get_tensor_model_parallel_rank() - max_dim = out.shape[-1] # 128 - valid_ranks = 6 - - if rank == valid_ranks: - mask[..., max_dim:] *= 0.0 - elif rank > valid_ranks: - mask *= 0.0 - out *= mask - - return out - - -def get_internvit_layer_spec(use_te) -> ModuleSpec: - mlp = get_mlp_module_spec(use_te) # no norm - - return ModuleSpec( - module=InternViTTransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=InternViTRMSNorm, - self_attention=ModuleSpec( - module=InternViTSelfAttention, - params={"attn_mask_type": AttnMaskType.no_mask}, - submodules=SelfAttentionSubmodules( - linear_qkv=TEColumnParallelLinear if use_te else ColumnParallelLinear, - core_attention=TEDotProductAttention if use_te else DotProductAttention, - linear_proj=TERowParallelLinear if use_te else RowParallelLinear, - q_layernorm=InternViTRMSNorm, - k_layernorm=InternViTRMSNorm, - ), - ), - self_attn_bda=get_bias_dropout_add_internvit, - pre_mlp_layernorm=InternViTRMSNorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add_internvit, - ), - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""" +NOTE: NVLM uses InternViT with tensor parallel (TP) size = 8. +Since InternViT has 25 attention heads and Megatron currently requires the number of attention heads +to be divisible by the TP size, we add 7 dummy zero attention heads to have 32 attention heads. + +This workaround requires some changes to how we compute RMSNorm, Attention etc. + +Additionally, InternViT introduces some unique features like Layer Scaling. + +Those code changes are gathered here. +""" +from functools import partial + +import torch + +from megatron.core.utils import divide +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + + +class InternViTRMSNorm(MegatronModule): + + def __init__( + self, + config, + hidden_size: int, + eps: float = 1e-6, + sequence_parallel: bool = False, + compute_var: bool = False, + ): + """Custom RMSNorm for InternViT. + + Args: + config (TransformerConfig): Config. + hidden_size (int): Input hidden size. + eps (float): epsilon to use for the norm, default to 1e-6 + sequence_parallel (bool): Set to true if sequence parallelism is being used, + this marks the weights as needing to be allreduced. + compute_var (bool): Indicator to compute statistic manually. 
+ """ + super().__init__(config=config) + self.config = config + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self._compute_var = compute_var + + assert not sequence_parallel, "Sequence parallelism is not supported with InternViT." + + setattr(self.weight, 'sequence_parallel', sequence_parallel) + + def _norm(self, x, var): + if var is None: + var = x.pow(2).mean(-1, keepdim=True) + + return x * torch.rsqrt(var + self.eps) + + def forward(self, x): + """Run RMSNorm with an option to compute custom statistic.""" + var = None + if self._compute_var: + unpadded_hidden_size = self.config.hidden_size # 3200 + max_dim = x.shape[-1] # 128 + + x = x.reshape(x.size(0), x.size(1), -1) + var = self._gather_var(x.float().pow(2), max_dim) / unpadded_hidden_size + + output = self._norm(x.float(), var).type_as(x) + output = output * self.weight + + if self._compute_var: + output = output.reshape(output.size(0), output.size(1), -1, max_dim) + + return output + + def _gather_var(self, input_, max_dim): + """Compute statistic across the non-dummy heads.""" + world_size = get_tensor_model_parallel_world_size() + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + valid_ranks = 24 // num_attention_heads_per_partition + + residual_heads = 25 % num_attention_heads_per_partition + if residual_heads == 0: + residual_heads = num_attention_heads_per_partition + max_dim = max_dim * residual_heads + + if rank < valid_ranks: # Ranks without any dummy attention heads. + var = input_.sum(-1, keepdim=True) + elif rank == valid_ranks: # The only rank which may contain 'residual_heads' dummy attention heads. + var = input_[..., :max_dim].sum(-1, keepdim=True) + else: + var = input_.sum(-1, keepdim=True) * 0.0 # All heads in these ranks are dummy heads: Zero-out. + + tensor_list = [torch.empty_like(var) for _ in range(world_size)] + tensor_list[rank] = var + torch.distributed.all_gather(tensor_list, var, group=get_tensor_model_parallel_group()) + + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output.sum(-1, keepdim=True) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata={}): + + # in InternVitSelfAttention the q_layernorm and k_layernorm weights + # are tensor-parallel so must be converted to sharded tensors + if 'q_layernorm' in prefix or 'k_layernorm' in prefix: + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0}, sharded_offsets + ) + else: + return super().sharded_state_dict(prefix, sharded_offsets, metadata) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +# Handle InternViT's layer scaling. 
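The rewritten _gather_var above drops the earlier hard-coded TP=8 assumption (valid_ranks=6 and a world size asserted to be 8) and derives the split from the config instead: only the 25 real InternViT heads contribute to the variance, and the zero-padded heads are excluded. A short editorial sketch of that accounting, assuming the padded 32-head setup; the helper name is illustrative only.

# Hedged sketch: the valid_ranks / residual_heads arithmetic for a few TP sizes.
def head_split(tp_size, padded_heads=32, real_heads=25):
    per_rank = padded_heads // tp_size            # num_attention_heads_per_partition
    full_ranks = 24 // per_rank                   # same constant as in the code above
    residual = real_heads % per_rank or per_rank  # real heads on the boundary rank
    return per_rank, full_ranks, residual

for tp in (1, 4, 8):
    per_rank, full_ranks, residual = head_split(tp)
    print(f"TP={tp}: heads/rank={per_rank}, fully-real ranks={full_ranks}, "
          f"real heads on the boundary rank={residual}")
# TP=8 -> 4 heads/rank, ranks 0-5 fully real, 1 real head on rank 6: the old TP=8-only behavior.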
+def _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training): + x, bias = x_with_bias # unpack + residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + else: + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + + +def bias_dropout_add_unfused_internvit(ls, training): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + + def _bias_dropout_add(x_with_bias, residual, prob): + return _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training) + + return _bias_dropout_add + + +def get_bias_dropout_add_internvit(ls, training, fused): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + assert not fused, "Fused bias-dropout-add not implemented for InternViT." + return bias_dropout_add_unfused_internvit(ls, training) + + +# Add InternViT specialties to our default TransformerLayer. +class InternViTTransformerLayer(TransformerLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ls1 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.ls2 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + + self.self_attn_bda = partial(self.self_attn_bda, self.ls1) + self.mlp_bda = partial(self.mlp_bda, self.ls2) + + +# Override a few things that are special in InternViT and not supported by the SelfAttention class. +class InternViTSelfAttention(SelfAttention): + def __init__( + self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs + ): + super().__init__(config=config, submodules=submodules, *args, **kwargs) + + # Need to override linear_qkv, q_layernorm and k_layernorm. + qkv_bias = False + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + qk_layernorm_hidden_size = ( + self.hidden_size_per_attention_head * self.num_attention_heads_per_partition + ) # 512 for internvit + + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + +class InternViTTEDotProductAttention(TEDotProductAttention): + """Adjusted Attention for InternViT""" + + def forward(self, *args, **kwargs): + """Regular TEDotProductAttention + zero-out dummy attention heads.""" + out = super().forward(*args, **kwargs) + + # This makes sure the dummy attention heads are zeroed out. 
+ mask = torch.ones_like(out, dtype=out.dtype, device=out.device) + rank = get_tensor_model_parallel_rank() + max_dim = out.shape[-1] # 128 + valid_ranks = 6 + + if rank == valid_ranks: + mask[..., max_dim:] *= 0.0 + elif rank > valid_ranks: + mask *= 0.0 + out *= mask + + return out + + +def get_internvit_layer_spec(use_te) -> ModuleSpec: + mlp = get_mlp_module_spec(use_te) # no norm + + return ModuleSpec( + module=InternViTTransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=InternViTRMSNorm, + self_attention=ModuleSpec( + module=InternViTSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear if use_te else ColumnParallelLinear, + core_attention=TEDotProductAttention if use_te else DotProductAttention, + linear_proj=TERowParallelLinear if use_te else RowParallelLinear, + q_layernorm=InternViTRMSNorm, + k_layernorm=InternViTRMSNorm, + ), + ), + self_attn_bda=get_bias_dropout_add_internvit, + pre_mlp_layernorm=InternViTRMSNorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add_internvit, + ), + ) diff --git a/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh old mode 100644 new mode 100755 diff --git a/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh old mode 100644 new mode 100755 index 00f94352774518b1c8dc478c98808a16a3398b75..a9ba430f1a42f377ea018afcb6f552dc204eee38 --- a/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh +++ b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh @@ -1,154 +1,155 @@ -#!/bin/bash - -# Your SBATCH commands here if using SLURM. - -# Please launch this script from megatron-lm root. - -# Train a multimodal model. - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export TOKENIZERS_PARALLELISM="false" - -DEBUG=0 - -if [[ $BATCH -eq 0 ]]; then - DATETIME=`date +'%y-%m-%d-%H-%M-%S'` - MODEL_NAME="mcore-nous-yi34b-internvit-mlp-${DATETIME}" -else - MODEL_NAME="mcore-nous-yi34b-internvit-mlp" -fi - -WORKSPACE="" -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" - -LOAD_NAME="combined-yi-34b-internvit-tp8-mcore" -CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}" - -DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" - - -if [[ $DEBUG -eq 1 ]]; then - MBZ=1 - BZ=1 - NW=0 - LI=1 - AD=0.0 - HD=0.0 - EXTRA_ARGS="" - ALLOW_NONDETERMINISTIC=1 -else - MBZ=1 - BZ=2048 - NW=8 - LI=5 - AD=0.1 - HD=0.1 - EXTRA_ARGS="" - ALLOW_NONDETERMINISTIC=1 -fi - -SEQ_LEN=256 # Image embeddings sequence length. -DECODER_SEQ_LEN=512 # Language model sequence length. 
-MAX_POS_EMBED=512 - - -OPTIONS=" \ - --swiglu \ - --use-distributed-optimizer \ - --num-workers ${NW} \ - --num-layers 60 \ - --hidden-size 7168 \ - --normalization RMSNorm \ - --num-attention-heads 56 \ - --exit-duration-in-mins 230 \ - --group-query-attention \ - --num-query-groups 8 \ - --ffn-hidden-size 20480 \ - --seq-length ${SEQ_LEN} \ - --decoder-seq-length ${DECODER_SEQ_LEN} \ - --max-position-embeddings ${MAX_POS_EMBED} \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ - --tokenizer-prompt-format nvlm-yi-34b \ - --vocab-size 64000 \ - --make-vocab-size-divisible-by 1 \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 5000000 \ - --disable-bias-linear \ - --tensor-model-parallel-size 8 \ - --language-model-type yi-34b \ - --vision-model-type internvit \ - --micro-batch-size ${MBZ} \ - --global-batch-size ${BZ} \ - --train-samples 122880000 \ - --lr-decay-samples 25600000 \ - --lr-warmup-samples 83200 \ - --lr 1e-4 \ - --min-lr 2.5e-5 \ - --lr-decay-style cosine \ - --clip-grad 10.0 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.014 \ - --attention-dropout ${AD} \ - --hidden-dropout ${HD} \ - --eod-mask-loss \ - --bf16 \ - --tensorboard-dir=${TENSORBOARD_DIR} \ - --freeze-LM \ - --freeze-ViT \ - --img-h 448 \ - --img-w 448 \ - --patch-dim 14 \ - --data-path ${DATA_TRAIN} \ - --dataloader-type external \ - --split 100,0,0 \ - --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ - --log-interval ${LI} \ - --save-interval 2000 \ - --eval-interval 500 \ - --eval-iters 10 \ - --log-params-norm \ - --log-num-zeros-in-grad \ - ${EXTRA_ARGS} \ - --save ${FINETUNE_DIR} \ - --load ${FINETUNE_DIR} \ - --dataloader-save ${FINETUNE_DIR}/dataloader \ - --pretrained-checkpoint ${CHECKPOINT_DIR} \ - --allow-missing-vision-projection-checkpoint \ - --disable-vision-class-token \ - --use-te \ - --use-checkpoint-args \ - --ckpt-format torch \ - --pixel-shuffle \ - --image-tag-type nvlm - " - -export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} -export NVTE_APPLY_QK_LAYER_SCALING=0 - -# Interactive or batch mode -if [[ $BATCH -eq 0 ]]; then - torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} -else - run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" - - DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` - - srun -l --verbose \ - --container-image \ - --container-mounts "" \ - --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ - sh -c "${run_cmd}" - - set +x -fi +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. 
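For context on the SEQ_LEN and DECODER_SEQ_LEN values set below: the image sequence length follows directly from the image and patch sizes used throughout these NVLM scripts. A hedged back-of-the-envelope sketch, assuming 448x448 inputs, patch size 14, the class token disabled, and a 2x2 pixel-shuffle merge (the merge factor is an assumption; the scripts only state the resulting lengths):

# Hedged sketch: where SEQ_LEN=256 (1024 without pixel shuffle) comes from.
img_h = img_w = 448
patch_dim = 14
patches = (img_h // patch_dim) * (img_w // patch_dim)   # 32 * 32 = 1024 image embeddings
pixel_shuffle_merge = 2                                  # assumed 2x2 spatial merge
seq_len = patches // (pixel_shuffle_merge ** 2)          # 256 with --pixel-shuffle
print(patches, seq_len)

The text-generation scripts later in this change use the same arithmetic: 1024 plain, 256 with --pixel-shuffle, and 261 when tiling adds a tokenized tile tag (256 image embeddings plus 5 tile-tag embeddings, per their comments).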
+ +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="combined-yi-34b-internvit-tp8-mcore" +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=2048 + NW=8 + LI=5 + AD=0.1 + HD=0.1 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=256 # Image embeddings sequence length. +DECODER_SEQ_LEN=512 # Language model sequence length. +MAX_POS_EMBED=512 + + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ + --tokenizer-prompt-format nvlm-yi-34b \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 1e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --split 100,0,0 \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --allow-missing-vision-projection-checkpoint \ + --disable-vision-class-token \ + --use-te \ + --use-checkpoint-args \ + --ckpt-format torch \ + --pixel-shuffle \ + --image-tag-type nvlm + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + 
--output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh old mode 100644 new mode 100755 index e3b001c7aaee4544fde590ee41a8ae0d01497d36..165682ed6c13cd91c1ba1a6454feac99b69d51d9 --- a/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh +++ b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh @@ -1,141 +1,141 @@ -#!/bin/bash - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NVTE_APPLY_QK_LAYER_SCALING=0 -export TOKENIZERS_PARALLELISM="false" - -INPUT_IMAGE_PATH="placeholder" -GROUNDTRUTH_PATH="placeholder" - -USE_TILING=0 -USE_PIXEL_SHUFFLE_ONLY=0 - -while [[ $# -gt 0 ]]; do - case $1 in - --input-image-path) - INPUT_IMAGE_PATH="$2" - shift - shift - ;; - -o|--output-path) - OUTPUT_PATH="$2" - shift - shift - ;; - -m|--model-path) - MODEL_PATH="$2" - shift - shift - ;; - --task) - TASK="$2" - shift - shift - ;; - -g|--gt-path) - GROUNDTRUTH_PATH="$2" - shift - shift - ;; - --use-tiling) - USE_TILING=1 - shift - shift - ;; - --use-pixel-shuffle-only) - USE_PIXEL_SHUFFLE_ONLY=1 - shift - shift - ;; - -*|--*) - echo "Invalid option $1" - exit 1 - ;; - esac -done - -# Please modify these as needed. -NUM_PARTITIONS=0 -START=0 -END=0 - -SEQ_LEN=1024 # Image embeddings sequence length. -DECODER_SEQ_LEN=8192 # Language model sequence length. -MAX_POS_EMBED=8192 - -# Additional arguments. -EXTRA_ARGS="" - -if [[ $USE_TILING -eq 1 ]]; then - EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags" - SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). 
-fi - -if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then - EXTRA_ARGS+=" --pixel-shuffle" - SEQ_LEN=256 -fi - -for PARTITION_ID in $( eval echo {$START..$END} ) -do - torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ - --attention-softmax-in-fp32 \ - --no-masked-softmax-fusion \ - --swiglu \ - --num-layers 80 \ - --hidden-size 8192 \ - --normalization RMSNorm \ - --norm-epsilon 1e-06 \ - --num-attention-heads 64 \ - --exit-on-missing-checkpoint \ - --group-query-attention \ - --num-query-groups 8 \ - --ffn-hidden-size 29568 \ - --load ${MODEL_PATH} \ - --seq-length ${SEQ_LEN} \ - --decoder-seq-length ${DECODER_SEQ_LEN} \ - --max-position-embeddings ${MAX_POS_EMBED} \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model Qwen/Qwen2-72B-Instruct \ - --tokenizer-prompt-format qwen2p0 \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --disable-bias-linear \ - --add-qkv-bias \ - --tensor-model-parallel-size 8 \ - --pipeline-model-parallel-size 1 \ - --language-model-type qwen2.0_72B \ - --vision-model-type internvit \ - --micro-batch-size 1 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --bf16 \ - --freeze-LM \ - --freeze-ViT \ - --img-h 448 \ - --img-w 448 \ - --patch-dim 14 \ - --use-te \ - --transformer-impl transformer_engine \ - --use-checkpoint-args \ - --out-seq-length 16 \ - --temperature 1.0 \ - --patch-dim 14 \ - --seed 1234 \ - --top_k 1 \ - --no-load-rng \ - --no-load-optim \ - --num-partitions ${NUM_PARTITIONS} \ - --partition-id ${PARTITION_ID} \ - --output-path ${OUTPUT_PATH} \ - --gt-path ${GROUNDTRUTH_PATH} \ - --disable-vision-class-token \ - --input-image-path ${INPUT_IMAGE_PATH} \ - --gt-path ${GROUNDTRUTH_PATH} \ - ${EXTRA_ARGS} \ - --task ${TASK} \ - --image-tag-type nvlm \ - --ckpt-format torch -done +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +USE_TILING=0 +USE_PIXEL_SHUFFLE_ONLY=0 + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + --use-tiling) + USE_TILING=1 + shift + shift + ;; + --use-pixel-shuffle-only) + USE_PIXEL_SHUFFLE_ONLY=1 + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=1024 # Image embeddings sequence length. +DECODER_SEQ_LEN=8192 # Language model sequence length. +MAX_POS_EMBED=8192 + +# Additional arguments. +EXTRA_ARGS="" + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags" + SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). 
+fi + +if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle" + SEQ_LEN=256 +fi + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --no-masked-softmax-fusion \ + --swiglu \ + --num-layers 80 \ + --hidden-size 8192 \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --num-attention-heads 64 \ + --exit-on-missing-checkpoint \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 29568 \ + --load ${MODEL_PATH} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2-72B-Instruct \ + --tokenizer-prompt-format qwen2p0 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --disable-bias-linear \ + --add-qkv-bias \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --language-model-type qwen2.0_72B \ + --vision-model-type internvit \ + --micro-batch-size 1 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --bf16 \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --use-te \ + --transformer-impl transformer_engine \ + --use-checkpoint-args \ + --out-seq-length 16 \ + --temperature 1.0 \ + --patch-dim 14 \ + --seed 1234 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --disable-vision-class-token \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + ${EXTRA_ARGS} \ + --task ${TASK} \ + --image-tag-type nvlm \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh new file mode 100755 index 0000000000000000000000000000000000000000..df1e900ee2d5be0b7fa0ebf7359e3ffda97c63bb --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + --input-metadata-path) + INPUT_METADATA_PATH="$2" + shift + shift + ;; + --num-frames) + NUM_FRAMES="$2" + shift + shift + ;; + -g|--groundtruth-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + + +# Please modify these as needed. 
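+# NUM_PARTITIONS controls how many chunks the evaluation data is split into (0 = no partitioning); the loop below runs partition ids START through END.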
+NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=256 +DECODER_SEQ_LEN=16384 + +EXTRA_ARGS=" --pixel-shuffle" + + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --language-model-type=qwen2.5_7B \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 4 \ + --num-layers 28 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --num-attention-heads 28 \ + --max-position-embeddings 32768 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --out-seq-length 128 \ + --temperature 1.0 \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + ${EXTRA_ARGS} \ + --special-tokens "" "" "" \ + --vision-model-type internvit \ + --num-frames ${NUM_FRAMES} \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh old mode 100644 new mode 100755 index 3b6221996c8294790b946f3c453d01eb71b692e7..d66640fcc170e855ba84b077902897f294b7eb91 --- a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh +++ b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh @@ -1,111 +1,111 @@ -#!/bin/bash - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NVTE_APPLY_QK_LAYER_SCALING=0 -export TOKENIZERS_PARALLELISM="false" - -INPUT_IMAGE_PATH="placeholder" -GROUNDTRUTH_PATH="placeholder" - -while [[ $# -gt 0 ]]; do - case $1 in - -i|--input-image-path) - INPUT_IMAGE_PATH="$2" - shift - shift - ;; - -o|--output-path) - OUTPUT_PATH="$2" - shift - shift - ;; - -m|--model-path) - MODEL_PATH="$2" - shift - shift - ;; - -t|--task) - TASK="$2" - shift - shift - ;; - -g|--gt-path) - GROUNDTRUTH_PATH="$2" - shift - shift - ;; - -*|--*) - echo "Invalid option $1" - exit 1 - ;; - esac -done - -# Please modify these as needed. 
-NUM_PARTITIONS=0 -START=0 -END=0 - - -SEQ_LEN=256 -DECODER_SEQ_LEN=8192 -EXTRA_ARGS=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail" - -for PARTITION_ID in $( eval echo {$START..$END} ) -do - torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ - --attention-softmax-in-fp32 \ - --transformer-impl transformer_engine \ - --use-te \ - --use-checkpoint-args \ - --normalization RMSNorm \ - --norm-epsilon 1e-06 \ - --language-model-type=qwen2.5_7B \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --group-query-attention \ - --num-query-groups 4 \ - --num-layers 28 \ - --hidden-size 3584 \ - --ffn-hidden-size 18944 \ - --add-qkv-bias \ - --num-attention-heads 28 \ - --max-position-embeddings 32768 \ - --no-masked-softmax-fusion \ - --load ${MODEL_PATH} \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ - --tokenizer-prompt-format qwen2p5 \ - --bf16 \ - --micro-batch-size 1 \ - --seq-length ${SEQ_LEN} \ - --decoder-seq-length ${DECODER_SEQ_LEN} \ - --out-seq-length 128 \ - --temperature 1.0 \ - --img-h 448 \ - --img-w 448 \ - --patch-dim 14 \ - --seed 153 \ - --top_k 1 \ - --no-load-rng \ - --no-load-optim \ - --input-image-path ${INPUT_IMAGE_PATH} \ - --num-partitions ${NUM_PARTITIONS} \ - --partition-id ${PARTITION_ID} \ - --output-path ${OUTPUT_PATH} \ - --gt-path ${GROUNDTRUTH_PATH} \ - --task ${TASK} \ - ${EXTRA_ARGS} \ - --special-tokens "" "" "" \ - --vision-model-type siglip \ - --ckpt-format torch -done +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +while [[ $# -gt 0 ]]; do + case $1 in + -i|--input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + -t|--task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. 
+NUM_PARTITIONS=0 +START=0 +END=0 + + +SEQ_LEN=256 +DECODER_SEQ_LEN=8192 +EXTRA_ARGS=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail" + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --language-model-type=qwen2.5_7B \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 4 \ + --num-layers 28 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --num-attention-heads 28 \ + --max-position-embeddings 32768 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --out-seq-length 128 \ + --temperature 1.0 \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + ${EXTRA_ARGS} \ + --special-tokens "" "" "" \ + --vision-model-type siglip \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh old mode 100644 new mode 100755 diff --git a/examples/multimodal/nvlm/sft_34b_internvit.sh b/examples/multimodal/nvlm/sft_34b_internvit.sh old mode 100644 new mode 100755 index 0dff9461dae1f38255093afc893ad1110bc5ad6b..7cdc854197c9e0ae91eac5f228f0ccb5c88c8f68 --- a/examples/multimodal/nvlm/sft_34b_internvit.sh +++ b/examples/multimodal/nvlm/sft_34b_internvit.sh @@ -1,160 +1,161 @@ -#!/bin/bash - -# Your SBATCH commands here if using SLURM. - -# Please launch this script from megatron-lm root. - -# Train a multimodal model. - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_ALGO=^NVLS -export TOKENIZERS_PARALLELISM="false" - - -DEBUG=0 - -if [[ $BATCH -eq 0 ]]; then - DATETIME=`date +'%y-%m-%d-%H-%M-%S'` - MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft-${DATETIME}" -else - MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft" -fi - -WORKSPACE="" -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" - -LOAD_NAME="mcore-nous-yi34b-internvit-mlp" # From pretraining -CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" - -DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" - - -if [[ $DEBUG -eq 1 ]]; then - MBZ=1 - BZ=1 - NW=0 - LI=1 - AD=0.0 - HD=0.0 - ALLOW_NONDETERMINISTIC=1 - - # Can run out of GPU memory in interactive memory without this. - # This is just for interactive testing purposes. Do not use for proper training. 
- EXTRA_ARGS=" --freeze-LM" -else - MBZ=1 - BZ=128 - NW=2 - LI=5 - AD=0.0 - HD=0.0 - ALLOW_NONDETERMINISTIC=1 - - EXTRA_ARGS="" -fi - -SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). -DECODER_SEQ_LEN=3200 # Language model sequence length. -MAX_POS_EMBED=3200 - -OPTIONS=" \ - --swiglu \ - --use-distributed-optimizer \ - --num-workers ${NW} \ - --num-layers 60 \ - --hidden-size 7168 \ - --normalization RMSNorm \ - --num-attention-heads 56 \ - --exit-duration-in-mins 230 \ - --group-query-attention \ - --num-query-groups 8 \ - --ffn-hidden-size 20480 \ - --seq-length ${SEQ_LEN} \ - --decoder-seq-length ${DECODER_SEQ_LEN} \ - --max-position-embeddings ${MAX_POS_EMBED} \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ - --tokenizer-prompt-format nvlm-yi-34b \ - --vocab-size 64000 \ - --make-vocab-size-divisible-by 1 \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 5000000 \ - --disable-bias-linear \ - --tensor-model-parallel-size 8 \ - --language-model-type yi-34b \ - --vision-model-type internvit \ - --micro-batch-size ${MBZ} \ - --global-batch-size ${BZ} \ - --train-samples 30000000 \ - --lr-decay-samples 25600000 \ - --lr-warmup-samples 83200 \ - --lr 2e-6 \ - --min-lr 2.5e-7 \ - --lr-decay-style cosine \ - --split 100,0,0 \ - --clip-grad 10 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.014 \ - --attention-dropout ${AD} \ - --hidden-dropout ${HD} \ - --eod-mask-loss \ - --bf16 \ - --tensorboard-dir=${TENSORBOARD_DIR} \ - --freeze-ViT \ - --img-h 448 \ - --img-w 448 \ - --patch-dim 14 \ - --data-path ${DATA_TRAIN} \ - --dataloader-type external \ - --dataloader-save ${FINETUNE_DIR}/dataloader \ - --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ - --log-interval ${LI} \ - --load ${FINETUNE_DIR} \ - --save ${FINETUNE_DIR} \ - --pretrained-checkpoint ${CHECKPOINT_DIR} \ - --save-interval 5000 \ - --eval-interval 500 \ - --eval-iters 10 \ - --log-params-norm \ - --log-num-zeros-in-grad \ - ${EXTRA_ARGS} \ - --disable-vision-class-token \ - --use-te \ - --ckpt-format torch \ - --pixel-shuffle \ - --use-tiling \ - --max-num-tiles 6 \ - --use-thumbnail \ - --use-tile-tags \ - --image-tag-type nvlm - " - -export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} -export NVTE_APPLY_QK_LAYER_SCALING=0 - -# Interactive or batch mode -if [[ $BATCH -eq 0 ]]; then - torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} -else - run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" - - DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` - - srun -l --verbose \ - --container-image \ - --container-mounts "" \ - --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ - sh -c "${run_cmd}" - - set +x -fi +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. 
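+# BATCH selects the launcher at the bottom of this script: 0 runs torchrun interactively, any other value submits the training command via srun.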
+ +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM="false" + + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="mcore-nous-yi34b-internvit-mlp" # From pretraining +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + # Can run out of GPU memory in interactive memory without this. + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS=" --freeze-LM" +else + MBZ=1 + BZ=128 + NW=2 + LI=5 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + EXTRA_ARGS="" +fi + +SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +DECODER_SEQ_LEN=3200 # Language model sequence length. +MAX_POS_EMBED=3200 + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ + --tokenizer-prompt-format nvlm-yi-34b \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 30000000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --lr-decay-style cosine \ + --split 100,0,0 \ + --clip-grad 10 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --load ${FINETUNE_DIR} \ + --save ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --save-interval 5000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --disable-vision-class-token \ + --use-te \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-tiling \ + --max-num-tiles 6 \ + --use-thumbnail \ + --use-tile-tags \ + --image-tag-type nvlm + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 
examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh old mode 100644 new mode 100755 diff --git a/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh b/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh new file mode 100755 index 0000000000000000000000000000000000000000..d7eb8e00280505d7c1827ba69de532687e97c7df --- /dev/null +++ b/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM=false + +USER=$SLURM_JOB_USER + +# Auto-detect batch or interactive mode. +which srun +BATCH=$((1-$?)) + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="qwen2.5-7B-internvit-video-sft-nvlm-${DATETIME}" +else + MODEL_NAME="qwen2.5-7B-internvitp-video-sft-nvlm" + DEBUG=0 +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR="${OUTPUT}/checkpoints" +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +# From pretraining. The pretraining checkpoint should have tensor parallel size to 4. +LOAD_NAME="mcore-qwen2p5-7b-internvit-tp4" + +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS="--freeze-LM" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=256 + NW=8 + AD=0.0 + HD=0.0 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +USE_TILING=1 +SEQ_LEN=1024 +DECODER_SEQ_LEN=16384 +MAX_POS_EMBED=32768 +TRAIN_SAMPLES=6602173 +WARMUP_SAMPLES=198065 + + +if [[ $BATCH -eq 0 ]]; then + # Runs out of GPU memory in interactive memory without this. 
+ EXTRA_ARGS+="--freeze-LM" +fi + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail" + SEQ_LEN=256 +fi + + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 28 \ + --hidden-size 3584 \ + --norm-epsilon 1e-06 \ + --normalization RMSNorm \ + --num-attention-heads 28 \ + --exit-duration-in-mins 110 \ + --group-query-attention \ + --num-query-groups 4 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --dataloader-seq-length ${DECODER_SEQ_LEN} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --pixel-shuffle \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --disable-bias-linear \ + --pipeline-model-parallel-size 1 \ + --tensor-model-parallel-size 4 \ + --language-model-type qwen2.5_7B \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --train-samples ${TRAIN_SAMPLES} \ + --lr-warmup-samples ${WARMUP_SAMPLES} \ + --lr-decay-style cosine \ + --clip-grad 10 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --split 100,0,0 \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --save-interval 500 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --distributed-timeout-minutes 60 \ + --allow-missing-vision-projection-checkpoint \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --disable-vision-class-token \ + --use-te \ + --ckpt-format torch \ + --num-frames 32 \ + --use-checkpoint-args \ + --image-tag-type internvl \ + --recompute-granularity full \ + --recompute-method block \ + --recompute-num-layers 28 \ + --recompute-vision \ +" + + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh old mode 100644 new mode 100755 index 90b0053d19fd3d556d336093afc3414425eb8664..6032a839c80f9e405f2ccffbcc7e6b63f2410b0f --- a/examples/multimodal/pretrain_mistral_clip.sh +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -1,128 +1,128 @@ -#!/bin/bash -# Pretrain a multimodal model. - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-pretraining" - -# Check that the user has set an output path for model checkpoints. 
-if [[ -z $WORKSPACE ]]; then - echo "Please set WORKSPACE for storing your model checkpoints." - exit 1 -fi - -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" - -if [[ -z $LOAD_NAME ]]; then - echo "Please set LOAD_NAME for input model name." - exit 1 -fi - -CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" - -DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" - -DEBUG=0 -if [[ $DEBUG -eq 1 ]]; then - BZ=32 - NW=2 - HD=0.0 - LI=1 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 -else - BZ=256 - NW=2 - HD=0.1 - LI=10 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 -fi - -OPTIONS=" \ - --apply-layernorm-1p \ - --attention-softmax-in-fp32 \ - --use-checkpoint-args \ - --use-distributed-optimizer \ - --transformer-impl transformer_engine \ - --use-te \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --num-workers ${NW} \ - --exit-duration-in-mins 230 \ - --use-flash-attn \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout ${HD} \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --seq-length 576 \ - --decoder-seq-length 1024 \ - --max-position-embeddings 4096 \ - --ffn-hidden-size 14336 \ - --train-iters 20000 \ - --micro-batch-size 1 \ - --global-batch-size ${BZ} \ - --lr-decay-iters 20000 \ - --lr-warmup-fraction .01 \ - --lr 0.00015 \ - --min-lr 1.0e-5 \ - --lr-decay-style cosine \ - --log-interval ${LI} \ - --eval-iters 10 \ - --eval-interval 1000 \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ - --tokenizer-prompt-format mistral \ - --data-path ${DATA_TRAIN} \ - --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ - --save-interval 1000 \ - --save ${FINETUNE_DIR} \ - --load ${FINETUNE_DIR} \ - --dataloader-save ${FINETUNE_DIR}/dataloader \ - --pretrained-checkpoint ${CHECKPOINT_DIR} \ - --split 100,0,0 \ - --clip-grad 1.0 \ - --weight-decay 1e-2 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.014 \ - --log-params-norm \ - --log-num-zeros-in-grad \ - --bf16 \ - --eod-mask-loss \ - --freeze-LM \ - --freeze-ViT \ - --patch-dim 14 \ - --img-h 336 \ - --img-w 336 \ - --dataloader-type external \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - --language-model-type=mistral_7b \ - --disable-vision-class-token \ - ${EXTRA_ARGS} \ - --distributed-timeout-minutes 60 \ - --allow-missing-vision-projection-checkpoint \ - --ckpt-format torch -" - -export NVTE_APPLY_QK_LAYER_SCALING=0 -export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} - +#!/bin/bash +# Pretrain a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-pretraining" + +# Check that the user has set an output path for model checkpoints. +if [[ -z $WORKSPACE ]]; then + echo "Please set WORKSPACE for storing your model checkpoints." 
+ exit 1 +fi + +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +if [[ -z $LOAD_NAME ]]; then + echo "Please set LOAD_NAME for input model name." + exit 1 +fi + +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" + +DEBUG=0 +if [[ $DEBUG -eq 1 ]]; then + BZ=32 + NW=2 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +else + BZ=256 + NW=2 + HD=0.1 + LI=10 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +fi + +OPTIONS=" \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout ${HD} \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 1024 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters 20000 \ + --micro-batch-size 1 \ + --global-batch-size ${BZ} \ + --lr-decay-iters 20000 \ + --lr-warmup-fraction .01 \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 1000 \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ + --save-interval 1000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 1e-2 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --bf16 \ + --eod-mask-loss \ + --freeze-LM \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --allow-missing-vision-projection-checkpoint \ + --ckpt-format torch +" + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} \ No newline at end of file diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index cbde6680cc26bb9b3b7d4592d466906452064e6d..b4699fe0184d5bec06af735f9367635b6815da73 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -1,515 +1,595 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -"""Generate text using a vision language model.""" -import json -import logging -import os -import sys -from functools import partial - -# Add megatron to the path. 
-sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) -) - -import torch -import yaml -from config import EvaluationConfig -from evaluation.evaluation_datasets import get_evaluation_dataset -from model import model_provider -from multimodal_args import add_multimodal_extra_args - -from megatron.core import parallel_state -from megatron.core.enums import ModelType -from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN -from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings -from megatron.inference.text_generation.api import generate_and_post_process -from megatron.inference.text_generation.forward_step import ForwardStep -from megatron.inference.text_generation.communication import broadcast_int_list -from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 -from megatron.training.checkpointing import load_checkpoint -from megatron.training.initialize import initialize_megatron - - -def add_text_generation_args(parser): - """Text generation arguments.""" - group = parser.add_argument_group(title='Vision language model text generation arguments') - - group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') - group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') - group.add_argument("--top_k", type=int, default=0, help='Top k sampling.') - group.add_argument( - "--out-seq-length", type=int, default=128, help='Length of the output generated text.' - ) - group.add_argument("--output-path", type=str, help='Output file path') - group.add_argument('--input-image-path', type=str, help="Input image directory") - group.add_argument( - '--num-partitions', type=int, default=0, help="Number of partitions for inputs." - ) - group.add_argument('--partition-id', type=int, default=0, help="Partition index") - group.add_argument("--gt-path", type=str, help="Optional ground truth file") - group.add_argument( - "--task", - type=str, - choices=[ - "captioning", - "TextVQA", - "VQAv2", - "ChartQA", - "MMMU", - "VideoMME", - "OCRBench", - "MathVista", - "AI2D", - ], - help="Generation task to run", - ) - group.add_argument( - "--num-samples-per-partition", type=int, default=0, help="Number of samples per partition" - ) - group.add_argument("--config-path", type=str, help="Evaluation config file to use.") - - # Add common multimodal arguments needed for e.g. building the model. - parser = add_multimodal_extra_args(parser) - - return parser - - -def get_evaluation_dataloader( - task, - input_image_path, - gt_path, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_samples_per_partition, - num_partitions, - partition_id, - num_frames, - num_workers, - vision_model_type, -): - """Build evaluation dataset.""" - dataset = get_evaluation_dataset( - task, - input_image_path, - gt_path, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_samples_per_partition, - num_partitions, - partition_id, - num_frames, - vision_model_type, - ) - - dp_rank = parallel_state.get_data_parallel_rank() - dp_world_size = parallel_state.get_data_parallel_world_size() - - sampler = torch.utils.data.DistributedSampler( - dataset, shuffle=False, num_replicas=dp_world_size, rank=dp_rank - ) - # TODO: Batched inference is not supported yet. 
- dataloader = torch.utils.data.DataLoader( - dataset, batch_size=None, num_workers=num_workers, sampler=sampler, pin_memory=True - ) - - return dataloader - - -def generate_samples(model, config: EvaluationConfig, print_output): - """Text generation using a trained vision language model.""" - args = get_args() - - dataloader = get_evaluation_dataloader( - config.task, - config.input_image_path, - config.gt_path, - args.img_h, - args.img_w, - args.use_tiling, - args.max_num_tiles, - args.use_thumbnail, - config.num_samples_per_partition, - config.num_partitions, - config.partition_id, - args.num_frames, - args.num_workers, - args.vision_model_type, - ) - - num_img_embeddings_per_tile = get_num_image_embeddings( - args.img_h, - args.img_w, - args.patch_dim, - args.vision_model_type, - args.disable_vision_class_token, - 1, - args.pixel_shuffle, - args.use_tile_tags, - ) - - for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): - imgs = imgs.to("cuda") - num_tiles = num_tiles.to("cuda") - - conv = get_conversation(config.task, question) - - forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length) - - if is_first_rank(): - resp_sentences, _, _, _ = generate_and_post_process( - model, - forward_step=forward_step, - prompts=[conv], - tokens_to_generate=config.out_seq_length, - top_k_sampling=config.top_k, - top_p_sampling=config.top_p, - add_BOS=False, - temperature=config.temperature, - random_seed=args.seed, - detokenize_segments=False, - data_parallel=True, - ) - - for generation in resp_sentences: - if isinstance(sample_id, torch.Tensor): - sample_id = sample_id.item() - - output = {"sample_id": sample_id} - - output_name = "" - if config.task == "captioning": - output_name = "caption" - elif config.task in ( - "TextVQA", - "VQAv2", - "ChartQA", - "OCRBench", - "MathVista", - "AI2D", - ): - output_name = "answer" - elif config.task in ("MMMU"): - output_name = "text" - elif config.task == "VideoMME": - output_name = "response" - output = question - else: - raise NotImplementedError("no output name defined for", config.task) - - prompt, generated = get_prompt_and_generated( - generation, args.tokenizer_prompt_format - ) - if config.task == "VideoMME": - output["questions"][0][output_name] = generated - else: - output["prompt"] = prompt - output[output_name] = generated - - if config.task == "captioning": - output["ground_truth"] = answers - elif config.task in ( - "TextVQA", - "VQAv2", - "ChartQA", - "OCRBench", - "MathVista", - "AI2D", - ): - if isinstance(answers, str): - answers = [answers] - output["gt_answer"] = answers - - if len(metadata) > 0: - output.update(metadata) - elif config.task == "MMMU": - output["prediction"] = generated - output.update(metadata) - else: - raise NotImplementedError("no output processing defined for", config.task) - - if print_output: - print(output) - - yield output - idx += 1 - else: - generate_and_post_process( - model, forward_step=forward_step, detokenize_segments=False, data_parallel=True - ) - - idx += 1 - - -def get_evaluation_config(): - """Get evaluation config from a config file or command-line arguments.""" - args = get_args() - if args.config_path: - with open(args.config_path, "r") as f: - config_dict = yaml.safe_load(f) - - config = EvaluationConfig(**config_dict) - else: - config = EvaluationConfig( - task=args.task, - temperature=args.temperature, - top_p=args.top_p, - top_k=args.top_k, - out_seq_length=args.out_seq_length, - 
output_path=args.output_path, - input_image_path=args.input_image_path, - gt_path=args.gt_path, - num_partitions=args.num_partitions, - partition_id=args.partition_id, - num_samples_per_partition=args.num_samples_per_partition, - ) - - # Default output path if not defined... - if not config.output_path: - os.makedirs("generated", exist_ok=True) - config.output_path = "generated/" + args.language_model_type - - return config - - -def is_first_rank(): - """First tensor and pipeline parallel rank.""" - return ( - parallel_state.is_pipeline_first_stage(ignore_virtual=True) - and parallel_state.get_tensor_model_parallel_rank() == 0 - ) - - -def get_output_path(config, dp_rank): - """Generation output path.""" - return ( - f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}.jsonl" - ) - - -def generate_and_write_samples(model, config, print_output=True): - """Generate text and write to an output file.""" - dp_rank = parallel_state.get_data_parallel_rank() - - if is_first_rank(): - output_path = get_output_path(config, dp_rank) - output_file = open(output_path, "w") - print(f"output path: {output_file.name}") - - with torch.no_grad(): - for output in generate_samples(model, config, print_output): - if is_first_rank(): - output_file.write(json.dumps(output) + "\n") - output_file.flush() - - if is_first_rank(): - output_file.close() - - -class VLMForwardStep(ForwardStep): - """Inference forward step for a multimodal model.""" - - def __init__( - self, - num_img_embeddings_per_tile, - images, - num_tiles, - decoder_seq_length, - model, - max_batch_size, - max_sequence_length, - ): - """Create multimodal forward step.""" - total_num_tiles = torch.sum(num_tiles).item() - num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles - - super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings) - self._images = images - self._num_tiles = num_tiles - self._num_img_embeddings = num_img_embeddings - self.decoder_seq_length = decoder_seq_length - - self._recv_only_vision_embeds = False - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - # Checks if the previous stage only has a vision encoder, and that the current stage has part of the LM decoder. - # In this case, the current stage should only receive vision embeddings. - if pp_rank > 0: - self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder() - - # Checks if the current stage only has a vision encoder - self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() - - def _forward(self, tokens, position_ids, attention_mask): - return self.model( - self._images, - tokens, - position_ids, - attention_mask=None, - inference_params=self.inference_params, - num_image_tiles=self._num_tiles, - runtime_gather_output=True, - ) - - def __call__(self, tokens, position_ids, attention_mask): - num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() - num_tokens = tokens.size(1) - recv_buffer_seq_length = None - if num_image_tokens > 0: - # When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length. - # If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens. 
- # Note that this will set a recv_buffer_seq_length for the encoder stage, this length is irrelevant since that recv buffer is never allocated. - if self._recv_only_vision_embeds: - recv_buffer_seq_length = self._num_img_embeddings - else: - recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length) - elif self._recv_only_vision_embeds: - # If this stage only receives vision embeddings and there are no image tokens we won't run the encoder and therefore shouldn't try to recv. - recv_buffer_seq_length = 0 - - # If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens - if not (self._encoder_only and num_image_tokens == 0): - output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length) - else: - output = None - if isinstance(output, tuple): - logits, _ = output - else: - logits = output - - # On the first inference iteration, we compute image tokens. - # On every PP stage(although inference params should only matter for decoder), - # update the sequence length offset by the number of image tokens. - if num_tokens > 1 and num_image_tokens > 0: - if "image_tokens_count" not in self.inference_params.key_value_memory_dict: - self.inference_params.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings - - if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length: - self.inference_params.sequence_len_offset += self.decoder_seq_length - num_tokens - else: - self.inference_params.sequence_len_offset += ( - self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens - ) - - return logits - - -def get_conversation(task, question): - """Get a conversation for a given task and evaluation question.""" - conversation = [] - - # In all cases, the tokenizer adds possible header tokens for the assistant. - if task == "captioning": - conversation = [ - {"role": "system", "content": "Answer the questions."}, - { - "role": "user", - "content": f"{IMAGE_TOKEN}\nProvide a one-sentence caption for provided image.", - }, - ] - elif task in ("TextVQA", "VQAv2", "ChartQA"): - conversation = [ - {"role": "system", "content": "Answer the questions."}, - { - "role": "user", - "content": f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word or phrase.", - }, - ] - elif task in ("OCRBench", "MathVista", "AI2D"): - conversation = [ - {"role": "system", "content": "Answer the questions."}, - {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"}, - ] - elif task == "MMMU": - conversation = [ - {"role": "system", "content": "Answer the questions."}, - {"role": "user", "content": question}, - ] - elif task == "VideoMME": - q = ( - "Select the best answer to the following multiple-choice " - "question based on the video. 
Respond with only the letter " - "(A, B, C, or D) of the correct option.\n" - ) - q += question["questions"][0]["question"] + "\n" - q += question["questions"][0]["choices"][0] + "\n" - q += question["questions"][0]["choices"][1] + "\n" - q += question["questions"][0]["choices"][2] + "\n" - q += question["questions"][0]["choices"][3] + "\n" - - conversation = [ - {"role": "system", "content": "Answer the questions."}, - {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"}, - ] - - return conversation - - -def get_prompt_and_generated(prompt_and_generation, prompt_format): - """Strip prompt and other unnecessary text from generation.""" - if prompt_format == "llama3": - splitted = prompt_and_generation.split("<|start_header_id|>assistant<|end_header_id|>\n\n") - prompt = splitted[0] - generated = splitted[1] - generated = generated.split("<|eot_id|>")[0] - elif prompt_format == "mistral": - splitted = prompt_and_generation.split("[/INST]") - prompt = splitted[0] - generated = splitted[1] - generated = generated.split("")[0] - elif prompt_format == "chatml": - splitted = prompt_and_generation.split("<|im_start|> assistant\n") - prompt = splitted[0] - generated = splitted[1] - generated = generated.split("<|im_end|>")[0] - elif prompt_format in ("nvlm-yi-34b", "qwen2p0", "qwen2p5"): - splitted = prompt_and_generation.split("<|im_start|>assistant\n") - prompt = splitted[0] - generated = splitted[1] - generated = generated.split("<|im_end|>")[0] - else: - raise ValueError(f"Prompt format {prompt_format} is not supported.") - - # Remove possible garbage. - generated = generated.strip() - generated = generated.split("\n\n")[0] - generated = generated.split("\n")[0] - - return prompt, generated - - -def main(): - """Vision language model text generation.""" - initialize_megatron(extra_args_provider=add_text_generation_args) - - if torch.distributed.get_rank() == 0: - logging.getLogger(__name__).warning( - "Models using pipeline parallelism are not supported yet." - ) - - args = get_args() - - def wrapped_model_provider(pre_process, post_process, add_encoder, add_decoder): - return model_provider(pre_process, post_process, add_encoder, add_decoder, parallel_output=False) - - # Set up model and load checkpoint. - model = get_model(wrapped_model_provider, model_type=ModelType.encoder_and_decoder, wrap_with_ddp=False) - - if args.load is not None: - _ = load_checkpoint(model, None, None) - - model = model[0] - - model.eval() - - config = get_evaluation_config() - - generate_and_write_samples(model, config) - - -if __name__ == "__main__": - main() +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Generate text using a vision language model.""" +import json +import logging +import os +import sys +from functools import partial + +# Add megatron to the path. 
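+# (two directories up from examples/multimodal, so megatron can be imported without installing it as a package)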
+sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +import torch +import yaml +from config import EvaluationConfig +from evaluation.evaluation_datasets import get_evaluation_dataset +from model import model_provider +from multimodal_args import add_multimodal_extra_args + +from megatron.core import parallel_state +from megatron.core.enums import ModelType +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.inference.text_generation.api import generate_and_post_process +from megatron.inference.text_generation.forward_step import ForwardStep +from megatron.inference.text_generation.communication import broadcast_int_list +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest +from megatron.core.inference.text_generation_controllers.vlm_text_generation_controller import ( + VLMTextGenerationController, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import ( + VLMInferenceWrapper, +) +from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron + + +def add_text_generation_args(parser): + """Text generation arguments.""" + group = parser.add_argument_group(title='Vision language model text generation arguments') + + group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') + group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') + group.add_argument("--top_k", type=int, default=0, help='Top k sampling.') + group.add_argument( + "--out-seq-length", type=int, default=128, help='Length of the output generated text.' + ) + group.add_argument("--output-path", type=str, help='Output file path') + group.add_argument('--input-image-path', type=str, help="Input image directory") + group.add_argument( + '--num-partitions', type=int, default=0, help="Number of partitions for inputs." + ) + group.add_argument('--partition-id', type=int, default=0, help="Partition index") + group.add_argument("--gt-path", type=str, help="Optional ground truth file") + group.add_argument( + "--task", + type=str, + choices=[ + "captioning", + "TextVQA", + "VQAv2", + "ChartQA", + "MMMU", + "VideoMME", + "OCRBench", + "MathVista", + "AI2D", + "InfoVQA", + "SPDocVQA", + ], + help="Generation task to run", + ) + group.add_argument( + "--num-samples-per-partition", type=int, default=0, help="Number of samples per partition" + ) + group.add_argument("--config-path", type=str, help="Evaluation config file to use.") + + group.add_argument("--use-mcore-inference", action="store_true", default=False, help="Use the MCore inference API") + + # Add common multimodal arguments needed for e.g. building the model. 
+ parser = add_multimodal_extra_args(parser) + + return parser + + +def get_evaluation_dataloader( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + num_workers, + vision_model_type, +): + """Build evaluation dataset.""" + dataset = get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + vision_model_type, + ) + + dp_rank = parallel_state.get_data_parallel_rank() + dp_world_size = parallel_state.get_data_parallel_world_size() + + sampler = torch.utils.data.DistributedSampler( + dataset, shuffle=False, num_replicas=dp_world_size, rank=dp_rank + ) + # TODO: Batched inference is not supported yet. + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=None, num_workers=num_workers, sampler=sampler, pin_memory=True + ) + + return dataloader + + +def generate_samples(model, config: EvaluationConfig, print_output): + """Text generation using a trained vision language model.""" + args = get_args() + + dataloader = get_evaluation_dataloader( + config.task, + config.input_image_path, + config.gt_path, + args.img_h, + args.img_w, + args.use_tiling, + args.max_num_tiles, + args.use_thumbnail, + config.num_samples_per_partition, + config.num_partitions, + config.partition_id, + args.num_frames, + args.num_workers, + args.vision_model_type, + ) + + num_img_embeddings_per_tile = get_num_image_embeddings( + args.img_h, + args.img_w, + args.patch_dim, + args.vision_model_type, + args.disable_vision_class_token, + 1, + args.pixel_shuffle, + args.use_tile_tags, + ) + + if args.use_mcore_inference: + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=args.hidden_size, + inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, + fp32_residual_connection=args.fp32_residual_connection, + params_dtype=args.params_dtype, + padded_vocab_size=args.padded_vocab_size, + ) + inference_wrapped_model = VLMInferenceWrapper(model, inference_wrapper_config) + tokenizer = get_tokenizer() + controller = VLMTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer + ) + inference_engine = MCoreEngine( + controller, max_batch_size=1, random_seed=args.seed + ) + sampling_params = SamplingParams( + temperature=config.temperature, + top_k=config.top_k, + top_p=config.top_p, + num_tokens_to_generate=config.out_seq_length, + ) + + for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): + imgs = imgs.to("cuda") + num_tiles = num_tiles.to("cuda") + + conv = get_conversation(config.task, question) + + if not args.use_mcore_inference: + forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length) + + + if is_first_rank(): + + if args.use_mcore_inference: + inference_request = VLMInferenceRequest( + request_id=inference_engine.get_new_request_id(), + prompt=conv, + prompt_tokens=controller.tokenize_prompt(conv), + inference_parameters=sampling_params, + num_img_embeddings_per_tile=num_img_embeddings_per_tile, + imgs=imgs, + num_tiles=num_tiles, + decoder_seq_length=args.decoder_seq_length, + ) + results: List[InferenceRequest] = inference_engine.generate( + inference_requests=[inference_request] + ) + + resp_sentences = [ + tokenizer.detokenize(result.prompt_tokens) + 
result.generated_text + for result in results + ] + else: + resp_sentences, _, _, _ = generate_and_post_process( + model, + forward_step=forward_step, + prompts=[conv], + tokens_to_generate=config.out_seq_length, + top_k_sampling=config.top_k, + top_p_sampling=config.top_p, + add_BOS=False, + temperature=config.temperature, + random_seed=args.seed, + detokenize_segments=False, + data_parallel=True, + ) + + for generation in resp_sentences: + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.item() + + output = {"sample_id": sample_id} + + output_name = "" + if config.task == "captioning": + output_name = "caption" + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + "InfoVQA", + "SPDocVQA", + ): + output_name = "answer" + elif config.task in ("MMMU"): + output_name = "text" + elif config.task == "VideoMME": + output_name = "response" + output = question + else: + raise NotImplementedError("no output name defined for", config.task) + + prompt, generated = get_prompt_and_generated( + generation, args.tokenizer_prompt_format + ) + if config.task == "VideoMME": + output["questions"][0][output_name] = generated + else: + output["prompt"] = prompt + output[output_name] = generated + + if config.task == "captioning": + output["ground_truth"] = answers + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + "InfoVQA", + "SPDocVQA", + ): + if isinstance(answers, str): + answers = [answers] + output["gt_answer"] = answers + + if len(metadata) > 0: + output.update(metadata) + elif config.task == "MMMU": + output["prediction"] = generated + output.update(metadata) + else: + raise NotImplementedError("no output processing defined for", config.task) + + if print_output: + print(output) + + yield output + idx += 1 + else: + if args.use_mcore_inference: + inference_request = VLMInferenceRequest( + request_id=inference_engine.get_new_request_id(), + prompt=conv, + prompt_tokens=controller.tokenize_prompt(conv), + inference_parameters=sampling_params, + num_img_embeddings_per_tile=num_img_embeddings_per_tile, + imgs=imgs, + num_tiles=num_tiles, + decoder_seq_length=args.decoder_seq_length, + ) + inference_engine.generate( + inference_requests=[inference_request] + ) + else: + generate_and_post_process( + model, forward_step=forward_step, detokenize_segments=False, data_parallel=True + ) + + idx += 1 + + +def get_evaluation_config(): + """Get evaluation config from a config file or command-line arguments.""" + args = get_args() + if args.config_path: + with open(args.config_path, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + else: + config = EvaluationConfig( + task=args.task, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + out_seq_length=args.out_seq_length, + output_path=args.output_path, + input_image_path=args.input_image_path, + gt_path=args.gt_path, + num_partitions=args.num_partitions, + partition_id=args.partition_id, + num_samples_per_partition=args.num_samples_per_partition, + ) + + # Default output path if not defined... 
+ if not config.output_path: + os.makedirs("generated", exist_ok=True) + config.output_path = "generated/" + args.language_model_type + + return config + + +def is_first_rank(): + """First tensor and pipeline parallel rank.""" + return ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + and parallel_state.get_tensor_model_parallel_rank() == 0 + ) + + +def get_output_path(config, dp_rank): + """Generation output path.""" + return ( + f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}.jsonl" + ) + + +def generate_and_write_samples(model, config, print_output=True): + """Generate text and write to an output file.""" + dp_rank = parallel_state.get_data_parallel_rank() + + if is_first_rank(): + output_path = get_output_path(config, dp_rank) + output_file = open(output_path, "w") + print(f"output path: {output_file.name}") + + with torch.no_grad(): + for output in generate_samples(model, config, print_output): + if is_first_rank(): + output_file.write(json.dumps(output) + "\n") + output_file.flush() + + if is_first_rank(): + output_file.close() + +class VLMForwardStep(ForwardStep): + """Inference forward step for a multimodal model.""" + + def __init__( + self, + num_img_embeddings_per_tile, + images, + num_tiles, + decoder_seq_length, + model, + max_batch_size, + max_sequence_length, + ): + """Create multimodal forward step.""" + total_num_tiles = torch.sum(num_tiles).item() + num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles + + super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings) + self._images = images + self._num_tiles = num_tiles + self._num_img_embeddings = num_img_embeddings + self.decoder_seq_length = decoder_seq_length + + self._recv_only_vision_embeds = False + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + # Checks if the previous stage only has a vision encoder, and that the current stage has part of the LM decoder. + # In this case, the current stage should only receive vision embeddings. + if pp_rank > 0: + self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder() + + # Checks if the current stage only has a vision encoder + self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() + + def _forward(self, tokens, position_ids, attention_mask): + return self.model( + self._images, + tokens, + position_ids, + attention_mask=None, + inference_params=self.inference_params, + num_image_tiles=self._num_tiles, + runtime_gather_output=True, + ) + + def __call__(self, tokens, position_ids, attention_mask): + num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() + num_tokens = tokens.size(1) + recv_buffer_seq_length = None + if num_image_tokens > 0: + # When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length. + # If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens. + # Note that this will set a recv_buffer_seq_length for the encoder stage, this length is irrelevant since that recv buffer is never allocated. 
+ if self._recv_only_vision_embeds: + recv_buffer_seq_length = self._num_img_embeddings + else: + recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length) + elif self._recv_only_vision_embeds: + # If this stage only receives vision embeddings and there are no image tokens we won't run the encoder and therefore shouldn't try to recv. + recv_buffer_seq_length = 0 + + # If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens + if not (self._encoder_only and num_image_tokens == 0): + output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length) + else: + output = None + if isinstance(output, tuple): + logits, _ = output + else: + logits = output + + # On the first inference iteration, we compute image tokens. + # On every PP stage(although inference params should only matter for decoder), + # update the sequence length offset by the number of image tokens. + if num_tokens > 1 and num_image_tokens > 0: + if "image_tokens_count" not in self.inference_params.key_value_memory_dict: + self.inference_params.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings + + if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length: + self.inference_params.sequence_len_offset += self.decoder_seq_length - num_tokens + else: + self.inference_params.sequence_len_offset += ( + self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens + ) + + return logits + + +def get_conversation(task, question): + """Get a conversation for a given task and evaluation question.""" + conversation = [] + + # In all cases, the tokenizer adds possible header tokens for the assistant. + if task == "captioning": + conversation = [ + {"role": "system", "content": "Answer the questions."}, + { + "role": "user", + "content": f"{IMAGE_TOKEN}\nProvide a one-sentence caption for provided image.", + }, + ] + elif task in ("TextVQA", "VQAv2", "ChartQA", "InfoVQA", "SPDocVQA"): + conversation = [ + {"role": "system", "content": "Answer the questions."}, + { + "role": "user", + "content": f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word or phrase.", + }, + ] + elif task in ("OCRBench", "MathVista", "AI2D"): + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"}, + ] + elif task == "MMMU": + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": question}, + ] + elif task == "VideoMME": + q = ( + "Select the best answer to the following multiple-choice " + "question based on the video. 
Respond with only the letter " + "(A, B, C, or D) of the correct option.\n" + ) + q += question["questions"][0]["question"] + "\n" + q += question["questions"][0]["choices"][0] + "\n" + q += question["questions"][0]["choices"][1] + "\n" + q += question["questions"][0]["choices"][2] + "\n" + q += question["questions"][0]["choices"][3] + "\n" + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": f"{IMAGE_TOKEN}\n{q}"}, + ] + + return conversation + + +def get_prompt_and_generated(prompt_and_generation, prompt_format): + """Strip prompt and other unnecessary text from generation.""" + if prompt_format in ("llama3", "llama3p1"): + splitted = prompt_and_generation.split("<|start_header_id|>assistant<|end_header_id|>\n\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|eot_id|>")[0] + elif prompt_format == "mistral": + splitted = prompt_and_generation.split("[/INST]") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("")[0] + elif prompt_format == "chatml": + splitted = prompt_and_generation.split("<|im_start|> assistant\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|im_end|>")[0] + elif prompt_format in ("nvlm-yi-34b", "qwen2p0", "qwen2p5"): + splitted = prompt_and_generation.split("<|im_start|>assistant\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|im_end|>")[0] + else: + raise ValueError(f"Prompt format {prompt_format} is not supported.") + + # Remove possible garbage. + generated = generated.strip() + generated = generated.split("\n\n")[0] + generated = generated.split("\n")[0] + + return prompt, generated + + +def main(): + """Vision language model text generation.""" + initialize_megatron(extra_args_provider=add_text_generation_args) + + if torch.distributed.get_rank() == 0: + logging.getLogger(__name__).warning( + "Models using pipeline parallelism are not supported yet." + ) + + args = get_args() + + def wrapped_model_provider(pre_process, post_process, add_encoder, add_decoder): + return model_provider(pre_process, post_process, add_encoder, add_decoder, parallel_output=False) + + # Set up model and load checkpoint. + model = get_model(wrapped_model_provider, model_type=ModelType.encoder_and_decoder, wrap_with_ddp=False) + + if args.load is not None: + _ = load_checkpoint(model, None, None) + + model = model[0] + + model.eval() + + config = get_evaluation_config() + + generate_and_write_samples(model, config) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh old mode 100644 new mode 100755 index 94ff208eb4df632b597daa49bc3a1fbff62fe8d1..57e6d46393a13bd4d604fdf00924a9e8111a194e --- a/examples/multimodal/sft_mistral_clip.sh +++ b/examples/multimodal/sft_mistral_clip.sh @@ -1,130 +1,130 @@ -#!/bin/bash -# Run SFT on a pretrained multimodal model - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-sft" - -# Check that the user has set an output path for model checkpoints. -if [[ -z $WORKSPACE ]]; then - echo "Please set WORKSPACE for storing your model checkpoints." 
- exit 1 -fi - -SOURCE=`pwd` -OUTPUT_BASE="${WORKSPACE}/output" -OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" - -FINETUNE_DIR=${OUTPUT}/checkpoints -LOGS_DIR="${OUTPUT}/logs" -TENSORBOARD_DIR="${OUTPUT}/tensorboard" - -if [[ -z $LOAD_NAME ]]; then - echo "Please set LOAD_NAME for input model name." - exit 1 -fi - -if [[ -z $LOAD_ITER ]]; then - echo "Please set LOAD_ITER for pre-trained input model iteration." - exit 1 -fi - -CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" - -DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml" - -DEBUG=0 -if [[ $DEBUG -eq 1 ]]; then - BZ=8 - NW=1 - HD=0.0 - LI=1 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 -else - BZ=128 - NW=2 - HD=0.1 - LI=10 - EXTRA_ARGS="" - NONDETERMINISTIC_ATTN=1 -fi - -OPTIONS=" \ - --apply-layernorm-1p \ - --attention-softmax-in-fp32 \ - --use-checkpoint-args \ - --use-distributed-optimizer \ - --transformer-impl transformer_engine \ - --use-te \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --num-workers ${NW} \ - --exit-duration-in-mins 230 \ - --use-flash-attn \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout ${HD} \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --seq-length 576 \ - --decoder-seq-length 2048 \ - --max-position-embeddings 4096 \ - --ffn-hidden-size 14336 \ - --train-iters 20000 \ - --micro-batch-size 1 \ - --global-batch-size ${BZ} \ - --lr-decay-iters 20000 \ - --lr-warmup-fraction .01 \ - --lr 1e-6 \ - --min-lr 1e-7 \ - --lr-decay-style cosine \ - --log-interval ${LI} \ - --eval-iters 10 \ - --eval-interval 500 \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ - --tokenizer-prompt-format mistral \ - --data-path ${DATA_TRAIN} \ - --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ - --save-interval 500 \ - --save ${FINETUNE_DIR} \ - --load ${FINETUNE_DIR} \ - --pretrained-checkpoint ${CHECKPOINT_DIR} \ - --dataloader-save ${FINETUNE_DIR}/dataloader \ - --split 100,0,0 \ - --clip-grad 0.5 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.014 \ - --log-params-norm \ - --log-num-zeros-in-grad \ - --eod-mask-loss \ - --freeze-ViT \ - --patch-dim 14 \ - --img-h 336 \ - --img-w 336 \ - --dataloader-type external \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - --language-model-type=mistral_7b \ - --disable-vision-class-token \ - ${EXTRA_ARGS} \ - --distributed-timeout-minutes 60 \ - --ckpt-format torch -" - -export NVTE_APPLY_QK_LAYER_SCALING=0 -export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} - -torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +#!/bin/bash +# Run SFT on a pretrained multimodal model + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-sft" + +# Check that the user has set an output path for model checkpoints. +if [[ -z $WORKSPACE ]]; then + echo "Please set WORKSPACE for storing your model checkpoints." 
+ exit 1 +fi + +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +if [[ -z $LOAD_NAME ]]; then + echo "Please set LOAD_NAME for input model name." + exit 1 +fi + +if [[ -z $LOAD_ITER ]]; then + echo "Please set LOAD_ITER for pre-trained input model iteration." + exit 1 +fi + +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml" + +DEBUG=0 +if [[ $DEBUG -eq 1 ]]; then + BZ=8 + NW=1 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +else + BZ=128 + NW=2 + HD=0.1 + LI=10 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +fi + +OPTIONS=" \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout ${HD} \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 2048 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters 20000 \ + --micro-batch-size 1 \ + --global-batch-size ${BZ} \ + --lr-decay-iters 20000 \ + --lr-warmup-fraction .01 \ + --lr 1e-6 \ + --min-lr 1e-7 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ + --save-interval 500 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --split 100,0,0 \ + --clip-grad 0.5 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --eod-mask-loss \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --ckpt-format torch +" + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} + +torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh old mode 100644 new mode 100755 index c1ef7bcee897812fef976531e2a5bba961141b42..ed12b38af1fdeccd5818a98e12a2d5d98b75b908 --- a/examples/multimodal/text_generation_mistral_clip.sh +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -1,109 +1,109 @@ -#!/bin/bash - -export NCCL_IB_SL=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NVTE_APPLY_QK_LAYER_SCALING=0 - -INPUT_IMAGE_PATH="placeholder" -GROUNDTRUTH_PATH="placeholder" -NUM_FRAMES=1 - -while [[ $# -gt 0 ]]; do - case $1 in - 
-i|--input-image-path) - INPUT_IMAGE_PATH="$2" - shift - shift - ;; - --num-frames) - NUM_FRAMES="$2" - shift - shift - ;; - -o|--output-path) - OUTPUT_PATH="$2" - shift - shift - ;; - -m|--model-path) - MODEL_PATH="$2" - shift - shift - ;; - -t|--task) - TASK="$2" - shift - shift - ;; - -g|--gt-path) - GROUNDTRUTH_PATH="$2" - shift - shift - ;; - -*|--*) - echo "Invalid option $1" - exit 1 - ;; - esac -done - -# Please modify these as needed. -NUM_PARTITIONS=0 -START=0 -END=0 - -for PARTITION_ID in $( eval echo {$START..$END} ) -do - torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ - --apply-layernorm-1p \ - --attention-softmax-in-fp32 \ - --use-flash-attn \ - --transformer-impl transformer_engine \ - --use-te \ - --use-checkpoint-args \ - --normalization RMSNorm \ - --language-model-type mistral_7b \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --group-query-attention \ - --num-query-groups 8 \ - --num-layers 32 \ - --hidden-size 4096 \ - --ffn-hidden-size 14336 \ - --num-attention-heads 32 \ - --max-position-embeddings 4096 \ - --no-masked-softmax-fusion \ - --load ${MODEL_PATH} \ - --tokenizer-type MultimodalTokenizer \ - --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ - --tokenizer-prompt-format mistral \ - --bf16 \ - --micro-batch-size 1 \ - --seq-length 2048 \ - --out-seq-length 12 \ - --temperature 1.0 \ - --img-h 336 \ - --img-w 336 \ - --patch-dim 14 \ - --seed 153 \ - --top_k 1 \ - --no-load-rng \ - --no-load-optim \ - --input-image-path ${INPUT_IMAGE_PATH} \ - --num-partitions ${NUM_PARTITIONS} \ - --partition-id ${PARTITION_ID} \ - --output-path ${OUTPUT_PATH} \ - --gt-path ${GROUNDTRUTH_PATH} \ - --task ${TASK} \ - --disable-vision-class-token \ - --num-frames ${NUM_FRAMES} \ - --ckpt-format torch -done +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" +NUM_FRAMES=1 + +while [[ $# -gt 0 ]]; do + case $1 in + -i|--input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + --num-frames) + NUM_FRAMES="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + -t|--task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. 
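+# NUM_PARTITIONS shards the evaluation set across runs: for example, NUM_PARTITIONS=4 with
+# START=0 and END=3 runs generation once per partition id 0..3, writing one output file per
+# partition (see --partition-id and --output-path below).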
+NUM_PARTITIONS=0 +START=0 +END=0 + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-flash-attn \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --language-model-type mistral_7b \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 8 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --num-attention-heads 32 \ + --max-position-embeddings 4096 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 2048 \ + --out-seq-length 12 \ + --temperature 1.0 \ + --img-h 336 \ + --img-w 336 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + --disable-vision-class-token \ + --num-frames ${NUM_FRAMES} \ + --ckpt-format torch +done diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py index 1dc68d1173bfee00dd77d971c2b150b024acf421..72b141ddfe5f5ec5c19240af854a45beae74c6b0 100644 --- a/examples/multimodal/train.py +++ b/examples/multimodal/train.py @@ -1,300 +1,416 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -"""Pretrain or SFT multimodal.""" -import os -import sys -from functools import partial - -import torch -import yaml - -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) -) - -from dataloader_provider import train_valid_test_dataloaders_provider, is_first_or_last_stage -from model import model_provider -from multimodal_args import add_multimodal_extra_args - -from megatron.core import mpu, tensor_parallel -from megatron.core.enums import ModelType -from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, LLaVAModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.parallel_state import ( - get_tensor_model_parallel_rank, - get_pipeline_model_parallel_world_size, - is_pipeline_last_stage, -) -from megatron.training import get_args, get_timers, get_tokenizer, pretrain -from megatron.training.utils import is_last_rank - - -def get_batch(data_iterator): - """Generate a batch - - Note: attn_mask_type in layer_specs.py sets the attention mask. Attention mask is None here. - """ - imgs = None - tokens = None - labels = None - loss_mask = None - attention_mask = None - position_ids = None - num_tiles = None - packed_seq_params = None - - args = get_args() - - # Dataloader doesn't run on the middle stages in a pipeline parallel model. - pp_size = get_pipeline_model_parallel_world_size() - if not is_first_or_last_stage(pp_size, args.encoder_pipeline_model_parallel_size): - # Note these are all set to None above. - return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params - - # Broadcast data. 
- torch.cuda.nvtx.range_push("get_data") - if data_iterator is not None and get_tensor_model_parallel_rank() == 0: - data = next(data_iterator) - else: - data = None - - data_text = tensor_parallel.broadcast_data(["tokens"], data, torch.int64)["tokens"] - labels = tensor_parallel.broadcast_data(["labels"], data, torch.int64)["labels"] - - imgs = tensor_parallel.broadcast_data(["imgs"], data, torch.float32)["imgs"] - num_tiles = tensor_parallel.broadcast_data(["num_tiles"], data, torch.int32)["num_tiles"] - - cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"] - max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"] - - # No image input (text-only sample) if the dataloader produced a dummy image. - if imgs.shape == torch.Size([1, 1]): - # FIXME: text-only data can cause a hang if the vision model is own its own pipeline rank and --freeze-ViT is enabled. - imgs = torch.tensor([], dtype=torch.float32, device=data_text.device) - num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) - - # Last pipeline parallel stage doesn't need images. - if pp_size > 1 and is_pipeline_last_stage(): - imgs = None - - # If cu_lengths and max_lengths are non-dummy, construct PackedSeqParams. Otherwise, leave it at None. - if cu_lengths.shape != torch.Size([1, 1]): - assert ( - cu_lengths.shape[0] == max_lengths.shape[0] == 1 - ), "micro-batch-size must be 1 for packing" - cu_lengths = cu_lengths[0] - max_lengths = max_lengths[0] - - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_lengths, - cu_seqlens_kv=cu_lengths, - max_seqlen_q=max_lengths, - max_seqlen_kv=max_lengths, - ) - - torch.cuda.nvtx.range_pop() - - tokens_ = data_text.long() - - torch.cuda.nvtx.range_push("index tokens") - tokenizer = get_tokenizer() - text_length = tokens_.shape[1] - tokens = tokens_[:, :text_length].contiguous() - labels = labels[:, 1 : text_length + 1].contiguous() - - assert tokens.shape == labels.shape, f"tokens: {tokens.shape} != labels: {labels.shape}" - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("get_ltor_masks_and_position_ids") - loss_mask, position_ids = get_ltor_masks_and_position_ids(tokens, labels, tokenizer.pad) - torch.cuda.nvtx.range_pop() - - return ( - tokens, - labels, - loss_mask, - attention_mask, - position_ids, - imgs, - num_tiles, - packed_seq_params, - ) - - -def get_ltor_masks_and_position_ids(input_ids, target, pad_token): - """Build masks and position id for left to right model.""" - seq_length = input_ids.shape[1] - - # Position ids. - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - - # Loss mask. 
- loss_mask = torch.ones(target.size(), dtype=torch.float, device=input_ids.device) - loss_mask[target == pad_token] = 0.0 # mask paddings - loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts - - return loss_mask, position_ids - - -def loss_func(loss_mask, output_tensor): - losses = output_tensor.float() - - loss_mask = loss_mask.contiguous().view(-1).float() - - total_tokens = loss_mask.sum() - total_loss = torch.sum(losses.view(-1) * loss_mask) - loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) - - reporting_loss = loss.clone().detach() - torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) - - local_num_tokens = loss[1].clone().detach().to(torch.int) - - return (total_loss, local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])}) - - -def forward_step(data_iterator, model: LLaVAModel): - """Forward training step. - - Args: - data_iterator (torch.utils.data.dataloader): Input data iterator - model: Multimodal model - - Returns: - output_tensor (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. - loss_func (callable): Loss function with a loss mask specified. - """ - timers = get_timers() - - # Get the batch. - timers('batch-generator', log_level=2).start() - ( - tokens, - labels, - loss_mask, - attention_mask, - position_ids, - images, - num_image_tiles, - packed_seq_params, - ) = get_batch(data_iterator) - timers('batch-generator').stop() - - output_tensor, loss_mask = model( - images, - tokens, - position_ids, - attention_mask, - labels, - loss_mask, - num_image_tiles=num_image_tiles, - packed_seq_params=packed_seq_params, - ) - - return output_tensor, partial(loss_func, loss_mask) - - -def llava_embedding_ranks(pp_ranks): - """LLava's embedding ranks consist of the decoder's first and last ranks (ie, the ViT has no embeddings). - Args: - pp_ranks: A list of global ranks that constitute a pipeline group. - """ - args = get_args() - - # encoder size is also the index to the first rank of the decoder. - epp = args.encoder_pipeline_model_parallel_size - - last_rank = pp_ranks[-1] - if len(pp_ranks) == 1 or pp_ranks[epp] == last_rank: - return [last_rank] - else: - return [pp_ranks[epp], last_rank] - - -def llava_position_embedding_ranks(pp_ranks): - """LLava's embedding ranks consist of the singular rank of the model or the decoder's first rank. - Args: - pp_ranks: A list of global ranks that constitute a pipeline group. - """ - args = get_args() - - # encoder size is also the index to the first rank of the decoder. - epp = args.encoder_pipeline_model_parallel_size - - last_rank = pp_ranks[-1] - if len(pp_ranks) == 1: - return [last_rank] - else: - return [pp_ranks[epp]] - - -def run_online_eval(model): - """Run an evaluation benchmark during training.""" - args = get_args() - - # Online evaluation config is not defined. Do nothing. - if not args.online_evaluation_config: - return [] - - from config import EvaluationConfig - from run_text_generation import generate_and_write_samples - - with open(args.online_evaluation_config, "r") as f: - config_dict = yaml.safe_load(f) - - config = EvaluationConfig(**config_dict) - - # The inference code assumes the first rank is the leader. - # Tensorboard writer is on the last rank. - # We must write to a storage space that all ranks see. - output_dir = os.path.join(args.save, "online_eval") - os.makedirs(output_dir, exist_ok=True) - config.output_path = os.path.join(output_dir, args.language_model_type) - - # The actual generation. 
- generate_and_write_samples(model[0].module, config, print_output=False) - - # Make sure the first rank is done writing so that the last rank can run eval. - torch.distributed.barrier() - - if not is_last_rank(): - return [] - - # Run evaluation. - if config.task == "TextVQA": - from evaluate_textvqa import textvqa_eval - - avg_acc = textvqa_eval(config.output_path) - - return [{"TextVQA accuracy": avg_acc}] - else: - raise NotImplementedError(f"online evaluation of {config.task} not implemented yet") - - -def write_online_eval_to_tensorboard(data, iteration, writer): - """Write online evaluation data to Tensorboard.""" - if not writer: - return - - for item in data: - for k, v in item.items(): - writer.add_scalar(k, v, iteration) - - -if __name__ == "__main__": - - train_valid_test_dataloaders_provider.is_distributed = True - - pretrain( - train_valid_test_dataloaders_provider, - model_provider, - ModelType.encoder_and_decoder, - forward_step, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - extra_args_provider=add_multimodal_extra_args, - process_non_loss_data_func=write_online_eval_to_tensorboard, - get_embedding_ranks=llava_embedding_ranks, - get_position_embedding_ranks=llava_position_embedding_ranks, - non_loss_data_func=run_online_eval, - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Pretrain or SFT multimodal.""" +import math +import os +import sys +from functools import partial + +import torch +import yaml + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +from dataloader_provider import train_valid_test_dataloaders_provider, is_first_or_last_stage +from model import model_provider +from multimodal_args import add_multimodal_extra_args + +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.core.models.multimodal import context_parallel +from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, LLaVAModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_tensor_model_parallel_rank, + get_pipeline_model_parallel_world_size, + is_pipeline_last_stage, +) +from megatron.training import get_args, get_timers, get_tokenizer, pretrain +from megatron.training.utils import is_last_rank, get_batch_on_this_cp_rank + + +def get_batch(data_iterator, image_token_index, img_seq_len): + """Generate a batch + + Note: attn_mask_type in layer_specs.py sets the attention mask. Attention mask is None here. + """ + imgs = None + tokens = None + labels = None + loss_mask = None + attention_mask = None + position_ids = None + num_tiles = None + packed_seq_params = None + + args = get_args() + + # Dataloader doesn't run on the middle stages in a pipeline parallel model. + pp_size = get_pipeline_model_parallel_world_size() + if not is_first_or_last_stage(pp_size, args.encoder_pipeline_model_parallel_size): + # Note these are all set to None above. + return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params + + # Broadcast data. 
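+    # Only tensor-parallel rank 0 reads from the data iterator; tensor_parallel.broadcast_data
+    # then broadcasts each keyed tensor from rank 0 to the rest of the tensor-parallel group.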
+ torch.cuda.nvtx.range_push("get_data") + if data_iterator is not None and get_tensor_model_parallel_rank() == 0: + data = next(data_iterator) + else: + data = None + + data_text = tensor_parallel.broadcast_data(["tokens"], data, torch.int64)["tokens"] + labels = tensor_parallel.broadcast_data(["labels"], data, torch.int64)["labels"] + + imgs = tensor_parallel.broadcast_data(["imgs"], data, torch.float32)["imgs"] + num_tiles = tensor_parallel.broadcast_data(["num_tiles"], data, torch.int32)["num_tiles"] + + cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"] + max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"] + + # No image input (text-only sample) if the dataloader returned a size 1 image. + if imgs.shape == torch.Size([1, 1]): + # FSDP can hang with text-only samples. A workaround is to run a valid dummy image through the vision + # model and then add image embeddings with a zero multiplier. + if args.use_torch_fsdp2: + imgs = torch.zeros((1, 3, args.img_h, args.img_w), dtype=torch.float32, device=data_text.device) + num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) + else: + # Similar workaround is not needed without FSDP and we can use an empty image. + # FIXME: text-only data can cause still cause a hang in the special case where + # the vision model is own its own pipeline rank and --freeze-ViT is enabled. + imgs = torch.tensor([], dtype=torch.float32, device=data_text.device) + num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) + + # Last pipeline parallel stage doesn't need images. + if pp_size > 1 and is_pipeline_last_stage(): + imgs = None + + # If cu_lengths and max_lengths are non-dummy, construct PackedSeqParams. Otherwise, leave it at None. + if cu_lengths.shape != torch.Size([1, 1]): + assert ( + cu_lengths.shape[0] == max_lengths.shape[0] == 1 + ), "micro-batch-size must be 1 for packing" + cu_lengths = cu_lengths[0] + max_lengths = max_lengths[0] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_lengths, + cu_seqlens_kv=cu_lengths, + max_seqlen_q=max_lengths, + max_seqlen_kv=max_lengths, + ) + + torch.cuda.nvtx.range_pop() + + tokens_ = data_text.long() + + torch.cuda.nvtx.range_push("index tokens") + tokenizer = get_tokenizer() + text_length = tokens_.shape[1] + tokens = tokens_[:, :text_length].contiguous() + labels = labels[:, 1 : text_length + 1].contiguous() + + assert tokens.shape == labels.shape, f"tokens: {tokens.shape} != labels: {labels.shape}" + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("get_ltor_masks_and_position_ids") + loss_mask, position_ids = get_ltor_masks_and_position_ids(tokens, labels, tokenizer.pad) + torch.cuda.nvtx.range_pop() + + # If context parallel is enabled, must shard inputs to CP ranks. + if args.context_parallel_size > 1 or args.sequence_parallel: + assert tokens.shape[0], "micro-batch-size > 1 not supported yet with CP" + + num_image_tokens = torch.sum(tokens == image_token_index).item() + num_image_embeddings = num_image_tokens * img_seq_len - num_image_tokens + seq_len = text_length + num_image_embeddings + + # CP expects sequence length is divisible by CP size so apply padding. 
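+        # seq_len above counts text tokens plus the extra positions created when each image
+        # token is expanded into img_seq_len embeddings. The right-padding below uses zeros,
+        # so padded label/loss_mask positions never contribute to the loss.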
+ mp_padding_needed = context_parallel.get_padding( + seq_len, args.context_parallel_size, + args.tensor_model_parallel_size, args.sequence_parallel, + ) + tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed)) for item in (tokens, position_ids, labels, loss_mask)] + + # Get PackedSeqParams that indicate the amount of padding for TransformerEngine. + packed_seq_params = context_parallel.get_packed_seq_params(tokens, num_image_embeddings, mp_padding_needed, args.context_parallel_size, True) + + return ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + imgs, + num_tiles, + packed_seq_params, + ) + + +def get_ltor_masks_and_position_ids(input_ids, target, pad_token): + """Build masks and position id for left to right model.""" + seq_length = input_ids.shape[1] + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + # Loss mask. + loss_mask = torch.ones(target.size(), dtype=torch.float, device=input_ids.device) + loss_mask[target == pad_token] = 0.0 # mask paddings + loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts + + return loss_mask, position_ids + + +def get_mask_start_and_end_idx(arr): + """ + Returns a list of tuples holding the start and end index in arr of the non-zeros contiguuous + sub arrays. + + For instance, if arr = [0, 1, 0, 0, 1, 1] + get_mask_start_and_end_idx(arr) = [(1, 1), (4, 5)] + such that arr[1:1+1] = [1] and arr[4:5+1] = [1, 1] + """ + mask = (arr != 0) + + mask_int = mask.int() + + diff = mask_int[1:] - mask_int[:-1] + start_indices = (diff == 1).nonzero(as_tuple=False).flatten() + 1 + end_indices = (diff == -1).nonzero(as_tuple=False).flatten() + if len(mask)==0: return [] + if mask[0]: + start_indices = torch.cat((torch.tensor([0], device=arr.device), start_indices)) + if mask[-1]: + end_indices = torch.cat((end_indices, torch.tensor([len(arr) - 1], device=arr.device))) + sequences = list(zip(start_indices.tolist(), end_indices.tolist())) + return sequences + + +def scaled_loss_func(loss_mask, output_tensor): + """ + Scaled loss function + + Scale the loss for each conversation turn using the formula: + + 1 / sum_j[ sqrt(length(loss_turn_j)) ] * sum_i[ sum(loss_turn_i) / sqrt(length(loss_turn_i)) ] + + Where we use the loss mask to infer the start / end of the conversation turns. 
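+
+    For example, with two turns of 4 and 9 valid tokens the loss becomes
+        (sum(loss_turn_1) / sqrt(4) + sum(loss_turn_2) / sqrt(9)) / (sqrt(4) + sqrt(9))
+    i.e. (sum_1 / 2 + sum_2 / 3) / 5.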
+ """ + losses = output_tensor.float() + + loss_list = [] + num_valid_labels_list = [] + for idx in range(losses.shape[0]): + loss_this_sample = losses[idx] + turn_start_end_list = get_mask_start_and_end_idx(loss_mask[idx]) + for turn_start, turn_end in turn_start_end_list: + # compute loss for each turn + loss_this_turn = loss_this_sample[turn_start:turn_end+1].sum() + assert (1 - loss_mask)[idx][turn_start:turn_end+1].sum() < 1.0 + num_valid_labels_this_turn = turn_end - turn_start + 1 + loss_this_turn = loss_this_turn / num_valid_labels_this_turn + loss_list.append(loss_this_turn) + # append num of valid labels for each turn + num_valid_labels_list.append(num_valid_labels_this_turn) + base_num = sum([math.sqrt(each) for each in num_valid_labels_list]) + for idx in range(len(loss_list)): + # normalize loss for each turn + loss_list[idx] = loss_list[idx] * math.sqrt(num_valid_labels_list[idx]) / base_num + + total_loss = torch.stack(loss_list).sum() + total_tokens = torch.ones_like(total_loss) + + loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) + + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + + return ( + total_loss, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def loss_func(loss_mask, output_tensor): + args = get_args() + + losses = output_tensor.float() + + loss_mask = loss_mask.contiguous().view(-1).float() + + total_tokens = loss_mask.sum() + total_loss = torch.sum(losses.view(-1) * loss_mask) + loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + + # We multiply by context parallel size because later there will be a divide by CP(+DP) size. + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])} + ) + + +def forward_step(data_iterator, model: LLaVAModel): + """Forward training step. + + Args: + data_iterator (torch.utils.data.dataloader): Input data iterator + model: Multimodal model + + Returns: + output_tensor (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. + loss_func (callable): Loss function with a loss mask specified. + """ + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + images, + num_image_tiles, + packed_seq_params, + ) = get_batch(data_iterator, model.module.module.image_token_index, model.module.module.img_seq_len) + timers('batch-generator').stop() + + output_tensor, loss_mask = model( + images, + tokens, + position_ids, + attention_mask, + labels, + loss_mask, + num_image_tiles=num_image_tiles, + packed_seq_params=packed_seq_params, + ) + args = get_args() + if args.use_loss_scaling: + loss_function = partial(scaled_loss_func, loss_mask) + else: + loss_function = partial(loss_func, loss_mask) + + return output_tensor, loss_function + + +def llava_embedding_ranks(pp_ranks): + """LLava's embedding ranks consist of the decoder's first and last ranks (ie, the ViT has no embeddings). 
+ Args: + pp_ranks: A list of global ranks that constitute a pipeline group. + """ + args = get_args() + + # encoder size is also the index to the first rank of the decoder. + epp = args.encoder_pipeline_model_parallel_size + + last_rank = pp_ranks[-1] + if len(pp_ranks) == 1 or pp_ranks[epp] == last_rank: + return [last_rank] + else: + return [pp_ranks[epp], last_rank] + + +def llava_position_embedding_ranks(pp_ranks): + """LLava's embedding ranks consist of the singular rank of the model or the decoder's first rank. + Args: + pp_ranks: A list of global ranks that constitute a pipeline group. + """ + args = get_args() + + # encoder size is also the index to the first rank of the decoder. + epp = args.encoder_pipeline_model_parallel_size + + last_rank = pp_ranks[-1] + if len(pp_ranks) == 1: + return [last_rank] + else: + return [pp_ranks[epp]] + + +def run_online_eval(model): + """Run an evaluation benchmark during training.""" + args = get_args() + + # Online evaluation config is not defined. Do nothing. + if not args.online_evaluation_config: + return [] + + from config import EvaluationConfig + from run_text_generation import generate_and_write_samples + + with open(args.online_evaluation_config, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + + # The inference code assumes the first rank is the leader. + # Tensorboard writer is on the last rank. + # We must write to a storage space that all ranks see. + output_dir = os.path.join(args.save, "online_eval") + os.makedirs(output_dir, exist_ok=True) + config.output_path = os.path.join(output_dir, args.language_model_type) + + # The actual generation. + generate_and_write_samples(model[0].module, config, print_output=False) + + # Make sure the first rank is done writing so that the last rank can run eval. + torch.distributed.barrier() + + if not is_last_rank(): + return [] + + # Run evaluation. 
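+    # Only the last rank reaches this point; the returned [{metric name: value}] list is later
+    # written to TensorBoard by write_online_eval_to_tensorboard (see the pretrain() call below).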
+ if config.task == "TextVQA": + from evaluate_textvqa import textvqa_eval + + avg_acc = textvqa_eval(config.output_path) + + return [{"TextVQA accuracy": avg_acc}] + else: + raise NotImplementedError(f"online evaluation of {config.task} not implemented yet") + + +def write_online_eval_to_tensorboard(data, iteration, writer): + """Write online evaluation data to Tensorboard.""" + if not writer: + return + + for item in data: + for k, v in item.items(): + writer.add_scalar(k, v, iteration) + + +if __name__ == "__main__": + + train_valid_test_dataloaders_provider.is_distributed = True + + pretrain( + train_valid_test_dataloaders_provider, + model_provider, + ModelType.encoder_and_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + extra_args_provider=add_multimodal_extra_args, + process_non_loss_data_func=write_online_eval_to_tensorboard, + get_embedding_ranks=llava_embedding_ranks, + get_position_embedding_ranks=llava_position_embedding_ranks, + non_loss_data_func=run_online_eval, + ) diff --git a/examples/retro/preprocess_data.sh b/examples/retro/preprocess_data.sh old mode 100644 new mode 100755 diff --git a/examples/retro/train_retro_2b_distributed.sh b/examples/retro/train_retro_2b_distributed.sh old mode 100644 new mode 100755 diff --git a/examples/t5/train_t5_220m_distributed.sh b/examples/t5/train_t5_220m_distributed.sh old mode 100644 new mode 100755 index 62e6f9db4bd1c3a2d73d455e641708239e1add82..8793a992f938ebacddab664d10352394fc247988 --- a/examples/t5/train_t5_220m_distributed.sh +++ b/examples/t5/train_t5_220m_distributed.sh @@ -1,78 +1,78 @@ -#!/bin/bash - -# Runs the "220M" parameter model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NUM_NODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) - -CHECKPOINT_PATH=$1 # -TENSORBOARD_DIR=$2 # -VOCAB_FILE=$3 #/bert-large-cased-vocab.txt -DATA_PATH=$4 #_text_document - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NUM_NODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -T5_ARGS=" - --encoder-num-layers 12 \ - --decoder-num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --kv-channels 64 \ - --ffn-hidden-size 3072 \ - --encoder-seq-length 512 \ - --decoder-seq-length 128 \ - --max-position-embeddings 512 \ - --micro-batch-size 64 \ - --global-batch-size 512 \ - --lr 0.0001 \ - --train-iters 1000000 \ - --lr-decay-iters 1000000 \ - --lr-decay-style linear \ - --min-lr 0.00001 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --bf16 \ - --vocab-extra-ids 100 \ - --init-method-std 0.015 \ - --transformer-impl transformer_engine \ - --tensor-model-parallel-size 1 \ - --pipeline-model-parallel-size 1 \ - --attention-backend auto \ -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --tokenizer-type BertWordPieceCase \ - --split 99982,9,9 \ -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - --save-interval 500 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ - $T5_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ +#!/bin/bash + +# Runs the "220M" parameter model + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NUM_NODES=1 +NODE_RANK=0 
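+# WORLD_SIZE below is the total number of GPUs (GPUS_PER_NODE * NUM_NODES); for multi-node
+# runs, also set MASTER_ADDR to the rank-0 host and give each node its own NODE_RANK.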
+WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +CHECKPOINT_PATH=$1 # +TENSORBOARD_DIR=$2 # +VOCAB_FILE=$3 #/bert-large-cased-vocab.txt +DATA_PATH=$4 #_text_document + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NUM_NODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +T5_ARGS=" + --encoder-num-layers 12 \ + --decoder-num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --max-position-embeddings 512 \ + --micro-batch-size 64 \ + --global-batch-size 512 \ + --lr 0.0001 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --lr-decay-style linear \ + --min-lr 0.00001 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --bf16 \ + --vocab-extra-ids 100 \ + --init-method-std 0.015 \ + --transformer-impl transformer_engine \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --attention-backend auto \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --tokenizer-type BertWordPieceCase \ + --split 99982,9,9 \ +" + +OUTPUT_ARGS=" + --log-interval 100 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --save-interval 500 \ + --eval-interval 1000 \ + --eval-iters 10 +" + +torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ + $T5_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --distributed-backend nccl \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ diff --git a/gptnodes b/gptnodes deleted file mode 100644 index 523cbf3c0841a619bd0986e2ae01d80a0802c891..0000000000000000000000000000000000000000 --- a/gptnodes +++ /dev/null @@ -1,32 +0,0 @@ -node002 slots=8 -node003 slots=8 -node004 slots=8 -node005 slots=8 -node006 slots=8 -node020 slots=8 -node021 slots=8 -node022 slots=8 -node033 slots=8 -node034 slots=8 -node035 slots=8 -node036 slots=8 -node037 slots=8 -node038 slots=8 -node039 slots=8 -node040 slots=8 -node041 slots=8 -node042 slots=8 -node043 slots=8 -node044 slots=8 -node045 slots=8 -node046 slots=8 -node047 slots=8 -node048 slots=8 -node056 slots=8 -node057 slots=8 -node058 slots=8 -node059 slots=8 -node060 slots=8 -node061 slots=8 -node062 slots=8 -node063 slots=8 diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py index be0b7a4a08a8df931a2392c1a518e423cc4a3371..6b027fa61b4f4d6b6315dc602fb4e4c576c0eb77 100644 --- a/megatron/core/datasets/blended_dataset.py +++ b/megatron/core/datasets/blended_dataset.py @@ -1,201 +1,201 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import hashlib -import json -import logging -import os -import time -from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, Union - -import numpy -import torch - -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import normalize -from megatron.core.utils import log_single_rank - -logger = logging.getLogger(__name__) - -_VERBOSE = False - - -class BlendedDataset(torch.utils.data.Dataset): - """Conjugating class for a set of MegatronDataset instances - - Args: - datasets (List[MegatronDataset]): The MegatronDataset instances to blend - - weights (List[Union[int, float]]): The weights that determine the dataset blend ratios - - size (Optional[int]): The number of samples to draw from the blend. 
If None, for each dataset index idx draw exactly weights[idx] samples from datasets[idx]. - - config (BlendedMegatronDatasetConfig): The config - - Raises: - RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization - """ - - def __init__( - self, - datasets: List[MegatronDataset], - weights: List[Union[int, float]], - size: Optional[int], - config: BlendedMegatronDatasetConfig, - ) -> None: - assert len(datasets) == len(weights) - assert len(datasets) < 32767 - assert all(map(lambda _: type(_) == type(datasets[0]), datasets)) - assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets)) - assert all(map(lambda _: _ > 0, weights)) - assert all(map(lambda _: type(_) == type(weights[0]), weights)) - if size is None and isinstance(weights[0], float): - assert all(map(lambda _: _ == int(_), weights)) - - # Alert user to unnecessary blending - if len(datasets) == 1: - log_single_rank( - logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset" - ) - - if size is not None: - weights = normalize(weights) - - self.datasets = datasets - self.split = self.datasets[0].index_split - self.weights = weights - self.size = size - self.config = config - - unique_identifiers = OrderedDict() - unique_identifiers["class"] = type(self).__name__ - unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets] - unique_identifiers["split"] = self.split.name - unique_identifiers["weights"] = self.weights - unique_identifiers["size"] = self.size - unique_identifiers["renormalize_blend_weights"] = self.config.renormalize_blend_weights - - self.unique_description = json.dumps( - unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers - ) - self.unique_description_hash = hashlib.md5( - self.unique_description.encode("utf-8") - ).hexdigest() - - self.built_anew_on_cache_miss = False - - self.dataset_index, self.dataset_sample_index = self._build_indices() - - def __len__(self) -> int: - return self.dataset_index.shape[0] - - def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: - dataset_id = self.dataset_index[idx] - dataset_sample_id = self.dataset_sample_index[idx] - return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]} - - def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: - """Build and optionally cache the dataset index and the dataset sample index - - The dataset index is a 1-D mapping which determines the dataset to query. The dataset - sample index is a 1-D mapping which determines the sample to request from the queried - dataset. 
- - Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index - """ - path_to_cache = self.config.path_to_cache - - if path_to_cache: - get_path_to = lambda suffix: os.path.join( - path_to_cache, - f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}", - ) - path_to_description = get_path_to("description.txt") - path_to_dataset_index = get_path_to("dataset_index.npy") - path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy") - cache_hit = all( - map( - os.path.isfile, - [path_to_description, path_to_dataset_index, path_to_dataset_sample_index], - ) - ) - else: - cache_hit = False - - if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): - log_single_rank( - logger, logging.INFO, f"Build and save the {type(self).__name__} indices" - ) - self.built_anew_on_cache_miss = True - - # Build the dataset and dataset sample indexes - log_single_rank( - logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes" - ) - t_beg = time.time() - from megatron.core.datasets import helpers - - if self.size is not None: - dataset_index = numpy.zeros(self.size, dtype=numpy.int16) - dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64) - helpers.build_blending_indices( - dataset_index, - dataset_sample_index, - self.weights, - len(self.datasets), - self.size, - _VERBOSE, - ) - else: - size = sum(self.weights) - dataset_index = numpy.zeros(size, dtype=numpy.int16) - dataset_sample_index = numpy.zeros(size, dtype=numpy.int64) - helpers.build_exhaustive_blending_indices( - dataset_index, dataset_sample_index, self.weights, len(self.datasets) - ) - - if path_to_cache: - os.makedirs(path_to_cache, exist_ok=True) - # Write the description - with open(path_to_description, "wt") as writer: - writer.write(self.unique_description) - # Save the indexes - numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True) - numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True) - else: - log_single_rank( - logger, - logging.WARNING, - f"Unable to save the {type(self).__name__} indexes because path_to_cache is None", - ) - - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - return dataset_index, dataset_sample_index - - log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices") - - log_single_rank( - logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}" - ) - t_beg = time.time() - dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r') - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, - logging.INFO, - f"\tLoad the dataset sample index from {path_to_dataset_sample_index}", - ) - t_beg = time.time() - dataset_sample_index = numpy.load( - path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r' - ) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - return dataset_index, dataset_sample_index +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +import hashlib +import json +import logging +import os +import time +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import numpy +import torch + +from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig +from megatron.core.datasets.megatron_dataset import MegatronDataset +from megatron.core.datasets.utils import normalize +from megatron.core.utils import log_single_rank + +logger = logging.getLogger(__name__) + +_VERBOSE = False + + +class BlendedDataset(torch.utils.data.Dataset): + """Conjugating class for a set of MegatronDataset instances + + Args: + datasets (List[MegatronDataset]): The MegatronDataset instances to blend + + weights (List[Union[int, float]]): The weights that determine the dataset blend ratios + + size (Optional[int]): The number of samples to draw from the blend. If None, for each + dataset index idx draw exactly weights[idx] samples from datasets[idx]. + + config (BlendedMegatronDatasetConfig): The config + + Raises: + RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization + """ + + def __init__( + self, + datasets: List[MegatronDataset], + weights: List[Union[int, float]], + size: Optional[int], + config: BlendedMegatronDatasetConfig, + ) -> None: + assert len(datasets) == len(weights) + assert len(datasets) < 32767 + assert all(map(lambda _: type(_) == type(datasets[0]), datasets)) + assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets)) + assert all(map(lambda _: _ > 0, weights)) + assert all(map(lambda _: type(_) == type(weights[0]), weights)) + if size is None and isinstance(weights[0], float): + assert all(map(lambda _: _ == int(_), weights)) + + # Alert user to unnecessary blending + if len(datasets) == 1: + log_single_rank( + logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset" + ) + + if size is not None: + weights = normalize(weights) + + self.datasets = datasets + self.split = self.datasets[0].index_split + self.weights = weights + self.size = size + self.config = config + + unique_identifiers = OrderedDict() + unique_identifiers["class"] = type(self).__name__ + unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets] + unique_identifiers["split"] = self.split.name + unique_identifiers["weights"] = self.weights + unique_identifiers["size"] = self.size + + self.unique_description = json.dumps( + unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers + ) + self.unique_description_hash = hashlib.md5( + self.unique_description.encode("utf-8") + ).hexdigest() + + self.built_anew_on_cache_miss = False + + self.dataset_index, self.dataset_sample_index = self._build_indices() + + def __len__(self) -> int: + return self.dataset_index.shape[0] + + def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: + dataset_id = self.dataset_index[idx] + dataset_sample_id = self.dataset_sample_index[idx] + return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]} + + def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: + """Build and optionally cache the dataset index and the dataset sample index + + The dataset index is a 1-D mapping which determines the dataset to query. The dataset + sample index is a 1-D mapping which determines the sample to request from the queried + dataset. 
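+        When config.path_to_cache is set, rank 0 builds and saves the indices on a cache miss,
+        and later runs memory-map the cached .npy files instead of rebuilding them.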
+ + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index + """ + path_to_cache = self.config.path_to_cache + + if path_to_cache: + get_path_to = lambda suffix: os.path.join( + path_to_cache, + f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}", + ) + path_to_description = get_path_to("description.txt") + path_to_dataset_index = get_path_to("dataset_index.npy") + path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy") + cache_hit = all( + map( + os.path.isfile, + [path_to_description, path_to_dataset_index, path_to_dataset_sample_index], + ) + ) + else: + cache_hit = False + + if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): + log_single_rank( + logger, logging.INFO, f"Build and save the {type(self).__name__} indices" + ) + self.built_anew_on_cache_miss = True + + # Build the dataset and dataset sample indexes + log_single_rank( + logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes" + ) + t_beg = time.time() + from megatron.core.datasets import helpers + + if self.size is not None: + dataset_index = numpy.zeros(self.size, dtype=numpy.int16) + dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64) + helpers.build_blending_indices( + dataset_index, + dataset_sample_index, + self.weights, + len(self.datasets), + self.size, + _VERBOSE, + ) + else: + size = sum(self.weights) + dataset_index = numpy.zeros(size, dtype=numpy.int16) + dataset_sample_index = numpy.zeros(size, dtype=numpy.int64) + helpers.build_exhaustive_blending_indices( + dataset_index, dataset_sample_index, self.weights, len(self.datasets) + ) + + if path_to_cache: + os.makedirs(path_to_cache, exist_ok=True) + # Write the description + with open(path_to_description, "wt") as writer: + writer.write(self.unique_description) + # Save the indexes + numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True) + numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True) + else: + log_single_rank( + logger, + logging.WARNING, + f"Cannot save the {type(self).__name__} indexes because path_to_cache is None", + ) + + t_end = time.time() + log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + return dataset_index, dataset_sample_index + + log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices") + + log_single_rank( + logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}" + ) + t_beg = time.time() + dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r') + t_end = time.time() + log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + log_single_rank( + logger, + logging.INFO, + f"\tLoad the dataset sample index from {path_to_dataset_sample_index}", + ) + t_beg = time.time() + dataset_sample_index = numpy.load( + path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r' + ) + t_end = time.time() + log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + return dataset_index, dataset_sample_index diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py index c9cf4abf63c1451c253ebdeaae5d70e8d12dfc49..e69e0a66b65bc4e41ae613346ee3dc8491f07922 100644 --- a/megatron/core/datasets/blended_megatron_dataset_builder.py +++ b/megatron/core/datasets/blended_megatron_dataset_builder.py @@ -1,528 +1,579 @@ -# 
Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import logging -import math -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Iterable, List, Optional, Type, Union - -import numpy -import torch - -from megatron.core.datasets.blended_dataset import BlendedDataset -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset -from megatron.core.datasets.utils import Split, normalize -from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank -from megatron.core.utils import log_single_rank - -logger = logging.getLogger(__name__) - -MidLevelDataset = MegatronDataset - -TopLevelDataset = Union[BlendedDataset, MidLevelDataset] - -DistributedDataset = Union[ - TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset -] - - -class BlendedMegatronDatasetBuilder(object): - """Builder class for the BlendedDataset and MegatronDataset classes - - Args: - cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset - - sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per split - - is_built_on_rank (Callable): A callable which returns True if the dataset should be built on the current rank and False otherwise. It should be Megatron Core parallelism aware i.e. global rank, local group rank, and virtual rank may inform its return value. - - config (BlendedMegatronDatasetConfig): The config object which informs dataset creation - """ - - def __init__( - self, - cls: Type[MidLevelDataset], - sizes: List[int], - is_built_on_rank: Callable, - config: BlendedMegatronDatasetConfig, - ): - self.cls = cls - self.sizes = sizes - self.is_built_on_rank = is_built_on_rank - self.config = config - - log_single_rank( - logger, - logging.INFO, - f"Building dataset splits with cls={cls.__name__}, sizes={self.sizes}, and config={self.config}", - ) - - if not self.config.mock: - for split in Split: - size_is_none = self.sizes[split.value] is None - if self.config.blend_per_split is None: - weights_are_none = self.config.blend[1] is None - else: - if self.config.blend_per_split[split.value] is None: - continue - weights_are_none = self.config.blend_per_split[split.value][1] is None - if size_is_none: - assert ( - weights_are_none - ), f"size_is_none => weights_are_none fails for {split.name} split" - - if torch.distributed.is_initialized(): - gb_rank = torch.distributed.get_rank() - vp_rank = get_virtual_pipeline_model_parallel_rank() - if gb_rank == 0 and (vp_rank == 0 or vp_rank is None): - assert ( - self.is_built_on_rank() - ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0" - - def build(self) -> List[Optional[TopLevelDataset]]: - """Build all dataset splits according to the provided blend(s) - - This method is distributed-aware and must be called on all ranks. - - The dataset splits returned can vary according to the config. Supply config.blend and - config.split to build BlendedDataset and/or MegatronDataset splits from the same - distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset - splits from separate distributions. In either case, for each split, handle the following - cases: - - (1) The split is None - - do nothing - - (2) The split has one contributing dataset, and... 
- - (a) 'size' is not None - - Build a mid-level dataset with low-level dataset sampling in proportion to the size - - (b) 'size' is None - - Build mid-level datasets with no excess low-level dataset sampling - - (3) The split has multiple contributing datasets, and... - - (a) 'weights' is not None and 'size' is not None - - Build mid-level datasets with low-level dataset sampling in proportion to their weights and the size - - Build a top-level dataset of length marginally greater than 'size' with mid-level dataset sampling in proportion to their weights and the size - - (b) 'weights' is not None and 'size' is None - - Error - - (c) 'weights' is None and 'size' is not None - - Build mid-level datasets with no excess low-level dataset sampling - - Build a top-level dataset of length 'size' with mid-level dataset sampling in proportion to their lengths and the size - - - The 'size' of the top-level dataset is capped at the sum of the mid-level dataset lengths - - (d) 'weights' is None and 'size' is None - - Build mid-level datasets with no excess low-level dataset sampling - - Build a top-level dataset with no excess mid-level dataset sampling - - Returns: - List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per split - """ - datasets = self._build_blended_dataset_splits() - - for dataset in datasets: - if dataset is not None and len(dataset) > 0: - if isinstance(dataset, BlendedDataset): - if dataset.built_anew_on_cache_miss or any( - x.built_anew_on_cache_miss for x in dataset.datasets - ): - log_single_rank( - logger, - logging.INFO, - f"Verifying NumPy indices for {type(dataset).__name__} {dataset.split.name} split", - ) - else: - log_single_rank( - logger, - logging.INFO, - f"NumPy indices for {type(dataset).__name__} {dataset.split.name} split are fully cached, skipping verification", - ) - continue - # Check blend size - assert dataset.size is None or dataset.size == dataset.dataset_index.shape[0] - # Check blend access of mid-level datasets - _, sizes = numpy.unique(dataset.dataset_index, return_counts=True) - for i, dataset_and_size in enumerate(zip(dataset.datasets, sizes)): - if len(dataset_and_size[0]) < dataset_and_size[1]: - raise IndexError( - f"The {dataset.split.name} blend oversamples (N = {dataset_and_size[1]}) {type(dataset_and_size[0]).__name__} {i} (len = {len(dataset_and_size[0])}). " - f"Set renormalize_blend_weights to True and re-run. File an issue if the problem is not resolved." - ) - - return datasets - - def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]: - """Build all dataset splits according to the provided blend(s) - - See the BlendedMegatronDatasetBuilder.build alias for more information. 
- - Returns: - List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per split - """ - ## - # Return fake "mock" datasets - ## - if self.config.mock: - split = self.config.split_matrix - try: - return self._build_megatron_dataset_splits(None, split, self.sizes) - except Exception as error: - raise Exception( - f"{self.cls.__name__} failed to build as a mock data generator" - ) from error - - ## - # All splits come from the same distribution - ## - elif self.config.blend: - prefixes, weights = self.config.blend - if weights is not None: - weights = normalize(weights) - - split = self.config.split_matrix - - # Blend consists of a single prefix - if len(prefixes) == 1 and weights is None: - return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes) - - # Build the mid-level datasets - if weights is None: - sizes_per_dataset = [[None for split in Split] for prefix in prefixes] - else: - sizes_per_dataset = _get_size_per_split_per_dataset(weights, self.sizes) - - # build each dataset in parallel - megatron_datasets = self._build_megatron_datasets_parallel( - prefixes, split, sizes_per_dataset - ) - - # Build the top-level datasets - blended_datasets = [None] * len(Split) - for i in range(len(Split)): - if split[i] is not None: - weights_i = weights - if weights_i is not None and self.sizes[i] is not None: - size_per_dataset = list(zip(*sizes_per_dataset))[i] - size_i = sum(size_per_dataset) - if self.config.renormalize_blend_weights: - weights_i = list(map(lambda _size: _size / size_i, size_per_dataset)) - elif weights_i is None: - try: - weights_i = [ - len(megatron_dataset) for megatron_dataset in megatron_datasets[i] - ] - except TypeError: - weights_i = [0 for _ in prefixes] - if self.sizes[i] is not None: - size_i = min(self.sizes[i], sum(weights_i)) - else: - size_i = None # => the size will be sum(weights_i) - else: - raise RuntimeError - blended_datasets[i] = self.build_generic_dataset( - BlendedDataset, - self.is_built_on_rank, - True, # synchronize_ranks, default behavior to build on rank-0 first - megatron_datasets[i], - weights_i, - size_i, - self.config, - ) - - return blended_datasets - - ## - # Each split comes from a separate distribution - ## - else: - blended_datasets = [None] * len(Split) - for i in range(len(Split)): - split_spoof = [None] * len(Split) - split_spoof[i] = (0.0, 1.0) - sizes_spoof = [0] * len(Split) - sizes_spoof[i] = self.sizes[i] - - # Blend is provided for the split - blend = self.config.blend_per_split[i] - if blend is not None: - prefixes, weights = blend - if weights is not None: - weights = normalize(weights) - - # Blend consists of a sigle prefix - if len(prefixes) == 1: - blended_datasets[i] = self._build_megatron_dataset_splits( - prefixes[0], split_spoof, sizes_spoof - )[i] - continue - - # Build mid-level datasets - if weights is None: - sizes_per_dataset = [[None for split in Split] for prefix in prefixes] - else: - sizes_per_dataset = _get_size_per_split_per_dataset(weights, sizes_spoof) - - # build each dataset in parallel - megatron_datasets = self._build_megatron_datasets_parallel( - prefixes, split_spoof, sizes_per_dataset - )[i] - - # Build top-level dataset - if weights is not None and self.sizes[i] is not None: - size_per_dataset = list(zip(*sizes_per_dataset))[i] - size = sum(size_per_dataset) - if self.config.renormalize_blend_weights: - weights = list(map(lambda _size: _size / size, size_per_dataset)) - elif weights is None: - try: - weights = [ - len(megatron_dataset) for megatron_dataset in 
megatron_datasets - ] - except TypeError: - weights = [0 for _ in prefixes] - if self.sizes[i] is not None: - size = min(self.sizes[i], sum(weights)) - else: - size = None # => the size will be sum(weights) - else: - raise RuntimeError - blended_datasets[i] = self.build_generic_dataset( - BlendedDataset, - self.is_built_on_rank, - True, # synchronize_ranks, default behavior to build on rank-0 first - megatron_datasets, - weights, - size, - self.config, - ) - - return blended_datasets - - def _build_megatron_datasets_parallel( - self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]] - ) -> List[List[Optional[MegatronDataset]]]: - """Build the megatron datasets for a list of prefixes in parallel - - Args: - prefixes (List[str]): The list of prefix strings - - split (List[float]): The dataset split ratios (must sum to 1.00) - - sizes_per_dataset (List[List[int]]): The number of samples to request - per MegatronDataset per spilt - - Returns: - List[List[Optional[MegatronDataset]]]: For each split, have a list of - MegatronDataset per prefix - """ - - # Helper function to wrap the threading logic - def _threading_helper( - megatron_datasets: List[List[Optional[MegatronDataset]]], - num_workers: int, - prefixes: List[str], - split: List[float], - sizes_per_dataset: List[List[int]], - ) -> None: - with ThreadPoolExecutor(max_workers=num_workers) as executor: - all_futures = [] - for i in range(len(prefixes)): - all_futures.append( - executor.submit( - self._build_megatron_dataset_splits, - prefixes[i], - split, - sizes_per_dataset[i], - False, # synchronize_ranks, barrier is called in this function - ) - ) - for future in all_futures: - try: - megatron_datasets_split = future.result() - for j in range(len(megatron_datasets_split)): - megatron_datasets[j].append(megatron_datasets_split[j]) - except Exception as err: - raise err - - megatron_datasets = [[] for _ in range(len(Split))] - num_dataset_builder_threads = self.config.num_dataset_builder_threads - - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - # First, build on rank 0 - if rank == 0: - num_workers = num_dataset_builder_threads - if num_workers > 1: - # since only rank 0 is running, scale up the thread count - # but not too much to avoid overloading storage on miss path. - # if user set num_dataset_builder_threads to 1, - # i.e. meant for serial build, do not scale up. 
- num_workers *= min(2, max(1, torch.cuda.device_count())) - _threading_helper( - megatron_datasets, num_workers, prefixes, split, sizes_per_dataset - ) - - torch.distributed.barrier() - - # Then, build on other ranks; guaranteed to be data_cache hit - if rank != 0: - _threading_helper( - megatron_datasets, - num_dataset_builder_threads, - prefixes, - split, - sizes_per_dataset, - ) - else: - _threading_helper( - megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset - ) - - return megatron_datasets - - def _build_megatron_dataset_splits( - self, - dataset_path: Optional[str], - split: List[float], - sizes: List[int], - synchronize_ranks: bool = True, - ) -> List[Optional[MidLevelDataset]]: - """Build each MidLevelDataset split from a single LowLevelDataset - - Args: - dataset_path (Optional[str]): The path on disk which defines the underlying LowLevelDataset, or None for mock dataset classes - - split (List[Tuple[float, float]]): The dataset split matrix - - sizes (List[int]): The number of total samples to draw from each split - - synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks behavior. Set to False when we enforce this behavior at higher level. - - Returns: - List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split - """ - # short-cut if we are not building on this rank - if torch.distributed.is_initialized() and not self.is_built_on_rank(): - for i in range(len(Split)): - if split[i] is not None and synchronize_ranks: - torch.distributed.barrier() - return [None] * len(Split) - - # Build the low level dataset - low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config) - - # Build the split indices for the low level dataset - num_elements = self.cls.numel_low_level_dataset(low_level_dataset) - split_indices = [] - for i, _ in enumerate(Split): - if split[i] is not None: - beg = int(round(split[i][0] * float(num_elements))) - end = int(round(split[i][1] * float(num_elements))) - split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)) - else: - split_indices.append(None) - - # Build the mid level dataset - mid_level_datasets = [] - for i, _split in enumerate(Split): - if split[i] is None: - mid_level_datasets.append(None) - else: - mid_level_datasets.append( - self.build_generic_dataset( - self.cls, - self.is_built_on_rank, - synchronize_ranks, - low_level_dataset, - dataset_path, - split_indices[i], - sizes[i], - _split, - self.config, - ) - ) - - return mid_level_datasets - - @staticmethod - def build_generic_dataset( - cls: Union[Type[DistributedDataset], Callable], - is_built_on_rank: Callable, - synchronize_ranks: bool, - *args: Any, - ) -> Optional[Union[DistributedDataset, Iterable]]: - """Build the DistributedDataset - - Return None if and only if the underlying dataset class is not built on the current rank - and torch.distributed is initialized. - - Args: - cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be built. In special cases, e.g. when we are building the low level dataset for a RawMegatronDataset instance, we can accept a Callable which returns an Iterable. - - synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks behavior. Set to False when we enforce this behavior at higher level. 
- - args (Tuple[Any]): The positional arguments used to build the provided DistributedDataset class - - Raises: - Exception: When the dataset constructor raises an OSError - - Returns: - Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantion, the Iterable instantiation, or None - """ - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - - dataset = None - - # First, build on rank 0 - if rank == 0 and is_built_on_rank(): - try: - dataset = cls(*args) - except OSError as err: - log = ( - f"Failed to write dataset materials to the data cache directory. " - + f"Please supply a directory to which you have write access via " - + f"the path_to_cache attribute in BlendedMegatronDatasetConfig and " - + f"retry. Refer to the preserved traceback above for more information." - ) - raise Exception(log) from err - - if synchronize_ranks: - torch.distributed.barrier() - - # After, build on other ranks - if rank != 0 and is_built_on_rank(): - dataset = cls(*args) - - return dataset - - return cls(*args) - - -def _get_size_per_split_per_dataset( - normalized_weights: List[float], target_size_per_split: List[int] -) -> List[List[int]]: - """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits - - Args: - normalized_weights (List[float]): e.g. [0.3, 0.7] - - target_size_per_split (List[int]): The number of samples to target for each BlendedDataset split - - Returns: - List[List[int]]: The number of samples to request per MegatronDataset per split - """ - assert numpy.isclose(sum(normalized_weights), 1.0) - - # Use 0.5% target margin to ensure we satiate the request - sizes_per_dataset = [ - [int(math.ceil(target_size * weight * 1.005)) for target_size in target_size_per_split] - for weight in normalized_weights - ] - - return sizes_per_dataset +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import logging +import math +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Iterable, List, Optional, Type, Union + +import numpy +import torch + +from megatron.core.datasets.blended_dataset import BlendedDataset +from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig +from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset +from megatron.core.datasets.utils import Split, normalize +from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank +from megatron.core.utils import log_single_rank + +logger = logging.getLogger(__name__) + +MidLevelDataset = MegatronDataset + +TopLevelDataset = Union[BlendedDataset, MidLevelDataset] + +DistributedDataset = Union[ + TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset +] + + +class BlendedMegatronDatasetBuilder(object): + """Builder class for the BlendedDataset and MegatronDataset classes + + Args: + cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset + + sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per split + + is_built_on_rank (Callable): A callable which returns True if the dataset should be built on + the current rank and False otherwise. It should be Megatron Core parallelism aware i.e. + global rank, local group rank, and virtual rank may inform its return value. 
+ + config (BlendedMegatronDatasetConfig): The config object which informs dataset creation + """ + + def __init__( + self, + cls: Type[MidLevelDataset], + sizes: List[int], + is_built_on_rank: Callable, + config: BlendedMegatronDatasetConfig, + ): + self.cls = cls + self.sizes = sizes + self.is_built_on_rank = is_built_on_rank + self.config = config + + log_single_rank( + logger, + logging.INFO, + f"Building {cls.__name__} splits with sizes={self.sizes} and config={self.config}", + ) + + if not self.config.mock: + for split in Split: + size_is_none = self.sizes[split.value] is None + if self.config.blend_per_split is None: + weights_are_none = self.config.blend[1] is None + else: + if self.config.blend_per_split[split.value] is None: + continue + weights_are_none = self.config.blend_per_split[split.value][1] is None + if size_is_none: + assert ( + weights_are_none + ), f"size_is_none => weights_are_none fails for {split.name} split" + + if torch.distributed.is_initialized(): + gb_rank = torch.distributed.get_rank() + vp_rank = get_virtual_pipeline_model_parallel_rank() + if gb_rank == 0 and (vp_rank == 0 or vp_rank is None): + assert ( + self.is_built_on_rank() + ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0" + + def build(self) -> List[Optional[TopLevelDataset]]: + """Build all dataset splits according to the provided blend(s) + + This method is distributed-aware and must be called on all ranks. + + The dataset splits returned can vary according to the config. Supply config.blend and + config.split to build BlendedDataset and/or MegatronDataset splits from the same + distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset + splits from separate distributions. In either case, for each split, handle the following + cases: + + (1) The split is None + - do nothing + + (2) The split has one contributing dataset, and... + + (a) 'size' is not None + - Build a mid-level dataset with low-level dataset sampling in proportion to the + size + + (b) 'size' is None + - Build mid-level datasets with no excess low-level dataset sampling + + (3) The split has multiple contributing datasets, and... 
+ + (a) 'weights' is not None and 'size' is not None + - Build mid-level datasets with low-level dataset sampling in proportion to their + weights and the size + - Build a top-level dataset of length marginally greater than 'size' with mid-level + dataset sampling in proportion to their weights and the size + + (b) 'weights' is not None and 'size' is None + - Error + + (c) 'weights' is None and 'size' is not None + - Build mid-level datasets with no excess low-level dataset sampling + - Build a top-level dataset of length 'size' (capped at the sum of the mid-level + dataset lengths) with mid-level dataset sampling in proportion to their lengths + and the size + + (d) 'weights' is None and 'size' is None + - Build mid-level datasets with no excess low-level dataset sampling + - Build a top-level dataset with no excess mid-level dataset sampling + + Returns: + List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per + split + """ + datasets = self._build_blended_dataset_splits() + + for dataset in datasets: + if dataset is not None and len(dataset) > 0: + if isinstance(dataset, BlendedDataset): + if dataset.built_anew_on_cache_miss or any( + x.built_anew_on_cache_miss for x in dataset.datasets + ): + log_single_rank( + logger, + logging.INFO, + ( + f"Verifying NumPy indices for {type(dataset).__name__} " + f"{dataset.split.name} split" + ), + ) + else: + log_single_rank( + logger, + logging.INFO, + ( + f"NumPy indices for {type(dataset).__name__} {dataset.split.name} " + f"split are fully cached, skipping verification" + ), + ) + continue + # Check blend size + assert dataset.size is None or dataset.size == dataset.dataset_index.shape[0] + # Check blend access of mid-level datasets + dataset_indices, dataset_sizes = numpy.unique( + dataset.dataset_index, return_counts=True + ) + for i, (index, size) in enumerate(zip(dataset_indices, dataset_sizes)): + if len(dataset.datasets[index]) < size: + raise IndexError( + f"The {dataset.split.name} blend oversamples the contributing " + f"datasets and, e.g., requests {size} samples from " + f"{type(dataset.datasets[index]).__name__} {i} with size " + f"{len(dataset.datasets[index])}. This is unexpected. " + f"Please file an issue." + ) + + return datasets + + def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]: + """Build all dataset splits according to the provided blend(s) + + See the BlendedMegatronDatasetBuilder.build alias for more information. 
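        A minimal usage sketch, assuming MyMegatronDataset is a hypothetical
        MegatronDataset subclass (real dataset classes may additionally require a
        config subclass with dataset-specific fields):

            config = BlendedMegatronDatasetConfig(
                random_seed=1234,
                sequence_length=2048,
                blend=(["path/to/dataset-a", "path/to/dataset-b"], [0.3, 0.7]),
                split="99,1,0",
            )
            builder = BlendedMegatronDatasetBuilder(
                MyMegatronDataset, [10000, 1000, 0], lambda: True, config
            )
            train_ds, valid_ds, test_ds = builder.build()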
+ + Returns: + List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per + split + """ + ## + # Return fake "mock" datasets + ## + if self.config.mock: + split = self.config.split_matrix + try: + return self._build_megatron_dataset_splits(None, split, self.sizes) + except Exception as error: + raise Exception( + f"{self.cls.__name__} failed to build as a mock data generator" + ) from error + + ## + # All splits come from the same distribution + ## + elif self.config.blend: + prefixes, weights = self.config.blend + if weights is not None: + weights = normalize(weights) + + split = self.config.split_matrix + + # Blend consists of a single prefix + if len(prefixes) == 1 and weights is None: + return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes) + + # Build the mid-level datasets + if weights is None: + # Build only one "epoch" + sizes_per_dataset_buffer = [[None for split in Split] for prefix in prefixes] + else: + # The number of samples we plan to use per dataset + sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes) + # The number of samples we plan to build per dataset + sizes_per_dataset_buffer = _get_size_per_split_per_dataset( + weights, self.sizes, margin=0.5 + ) + + # Build each dataset in parallel + megatron_datasets = self._build_megatron_datasets_parallel( + prefixes, split, sizes_per_dataset_buffer + ) + + # Build the top-level datasets + blended_datasets = [None] * len(Split) + for i in range(len(Split)): + if split[i] is not None: + weights_i = weights + if weights_i is not None and self.sizes[i] is not None: + # Blend according to client-specified weights and client-specified size + size_per_dataset = list(zip(*sizes_per_dataset_target))[i] + size_i = sum(size_per_dataset) + elif weights_i is None: + # Blend according to dataset sizes as-is and (maybe) client-specified size + try: + weights_i = [ + len(megatron_dataset) for megatron_dataset in megatron_datasets[i] + ] + except TypeError: + weights_i = [0 for _ in prefixes] + if self.sizes[i] is not None: + size_i = min(self.sizes[i], sum(weights_i)) + else: + # Build exhaustive indices + size_i = None + else: + raise ValueError( + "Using client-specified weights requires client-specified size" + ) + blended_datasets[i] = self.build_generic_dataset( + BlendedDataset, + self.is_built_on_rank, + True, # synchronize_ranks, default behavior to build on rank-0 first + megatron_datasets[i], + weights_i, + size_i, + self.config, + ) + + return blended_datasets + + ## + # Each split comes from a separate distribution + ## + else: + blended_datasets = [None] * len(Split) + for i in range(len(Split)): + split_spoof = [None] * len(Split) + split_spoof[i] = (0.0, 1.0) + sizes_spoof = [0] * len(Split) + sizes_spoof[i] = self.sizes[i] + + # Blend is provided for the split + blend = self.config.blend_per_split[i] + if blend is not None: + prefixes, weights = blend + if weights is not None: + weights = normalize(weights) + + # Blend consists of a sigle prefix + if len(prefixes) == 1: + blended_datasets[i] = self._build_megatron_dataset_splits( + prefixes[0], split_spoof, sizes_spoof + )[i] + continue + + # Build mid-level datasets + if weights is None: + sizes_per_dataset_buffer = [ + [None for split in Split] for prefix in prefixes + ] + else: + # The number of samples we plan to use per dataset + sizes_per_dataset_target = _get_size_per_split_per_dataset( + weights, sizes_spoof + ) + # The number of samples we plan to build per dataset + sizes_per_dataset_buffer = 
_get_size_per_split_per_dataset( + weights, sizes_spoof, margin=0.5 + ) + + # Build each dataset in parallel + megatron_datasets = self._build_megatron_datasets_parallel( + prefixes, split_spoof, sizes_per_dataset_buffer + )[i] + + # Build top-level dataset + if weights is not None and self.sizes[i] is not None: + # Blend according to client-specified weights and client-specified size + size_per_dataset = list(zip(*sizes_per_dataset_target))[i] + size = sum(size_per_dataset) + elif weights is None: + # Blend according to dataset sizes as-is and (maybe) client-specified size + try: + weights = [ + len(megatron_dataset) for megatron_dataset in megatron_datasets + ] + except TypeError: + weights = [0 for _ in prefixes] + if self.sizes[i] is not None: + size = min(self.sizes[i], sum(weights)) + else: + # Build exhaustive indices + size = None + else: + raise RuntimeError + blended_datasets[i] = self.build_generic_dataset( + BlendedDataset, + self.is_built_on_rank, + True, # synchronize_ranks, default behavior to build on rank-0 first + megatron_datasets, + weights, + size, + self.config, + ) + + return blended_datasets + + def _build_megatron_datasets_parallel( + self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]] + ) -> List[List[Optional[MegatronDataset]]]: + """Build the megatron datasets for a list of prefixes in parallel + + Args: + prefixes (List[str]): The list of prefix strings + + split (List[float]): The dataset split ratios (must sum to 1.00) + + sizes_per_dataset (List[List[int]]): The number of samples to request + per MegatronDataset per spilt + + Returns: + List[List[Optional[MegatronDataset]]]: For each split, have a list of + MegatronDataset per prefix + """ + + # Helper function to wrap the threading logic + def _threading_helper( + megatron_datasets: List[List[Optional[MegatronDataset]]], + num_workers: int, + prefixes: List[str], + split: List[float], + sizes_per_dataset: List[List[int]], + ) -> None: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + all_futures = [] + for i in range(len(prefixes)): + all_futures.append( + executor.submit( + self._build_megatron_dataset_splits, + prefixes[i], + split, + sizes_per_dataset[i], + False, # synchronize_ranks, barrier is called in this function + ) + ) + for future in all_futures: + try: + megatron_datasets_split = future.result() + for j in range(len(megatron_datasets_split)): + megatron_datasets[j].append(megatron_datasets_split[j]) + except Exception as err: + raise err + + megatron_datasets = [[] for _ in range(len(Split))] + num_dataset_builder_threads = self.config.num_dataset_builder_threads + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + # First, build on rank 0 + if rank == 0: + num_workers = num_dataset_builder_threads + if num_workers > 1: + # since only rank 0 is running, scale up the thread count + # but not too much to avoid overloading storage on miss path. + # if user set num_dataset_builder_threads to 1, + # i.e. meant for serial build, do not scale up. 
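                    # For example, with num_dataset_builder_threads=4 and at least
                    # two visible GPUs, rank 0 builds with 4 * 2 = 8 worker threads;
                    # with a single visible GPU the thread count is left unchanged.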
+ num_workers *= min(2, max(1, torch.cuda.device_count())) + _threading_helper( + megatron_datasets, num_workers, prefixes, split, sizes_per_dataset + ) + + torch.distributed.barrier() + + # Then, build on other ranks; guaranteed to be data_cache hit + if rank != 0: + _threading_helper( + megatron_datasets, + num_dataset_builder_threads, + prefixes, + split, + sizes_per_dataset, + ) + else: + _threading_helper( + megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset + ) + + return megatron_datasets + + def _build_megatron_dataset_splits( + self, + dataset_path: Optional[str], + split: List[float], + sizes: List[int], + synchronize_ranks: bool = True, + ) -> List[Optional[MidLevelDataset]]: + """Build each MidLevelDataset split from a single LowLevelDataset + + Args: + dataset_path (Optional[str]): The path on disk which defines the underlying + LowLevelDataset, or None for mock dataset classes + + split (List[Tuple[float, float]]): The dataset split matrix + + sizes (List[int]): The number of total samples to draw from each split + + synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks + behavior. Set to False when we enforce this behavior at higher level. + + Returns: + List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split + """ + # short-cut if we are not building on this rank + if torch.distributed.is_initialized() and not self.is_built_on_rank(): + for i in range(len(Split)): + if split[i] is not None and synchronize_ranks: + torch.distributed.barrier() + return [None] * len(Split) + + # Build the low level dataset + low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config) + + # Build the split indices for the low level dataset + num_elements = self.cls.numel_low_level_dataset(low_level_dataset) + split_indices = [] + for i, _ in enumerate(Split): + if split[i] is not None: + beg = int(round(split[i][0] * float(num_elements))) + end = int(round(split[i][1] * float(num_elements))) + split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)) + else: + split_indices.append(None) + + # Build the mid level dataset + mid_level_datasets = [] + for i, _split in enumerate(Split): + if split[i] is None: + mid_level_datasets.append(None) + else: + mid_level_datasets.append( + self.build_generic_dataset( + self.cls, + self.is_built_on_rank, + synchronize_ranks, + low_level_dataset, + dataset_path, + split_indices[i], + sizes[i], + _split, + self.config, + ) + ) + + return mid_level_datasets + + @staticmethod + def build_generic_dataset( + cls: Union[Type[DistributedDataset], Callable], + is_built_on_rank: Callable, + synchronize_ranks: bool, + *args: Any, + ) -> Optional[Union[DistributedDataset, Iterable]]: + """Build the DistributedDataset + + Return None if and only if the underlying dataset class is not built on the current rank + and torch.distributed is initialized. + + Args: + cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be + built. In special cases, e.g. when we are building the low level dataset for a + RawMegatronDataset instance, we can accept a Callable which returns an Iterable. + + synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks + behavior. Set to False when we enforce this behavior at higher level. 
+ + args (Tuple[Any]): The positional arguments used to build the provided + DistributedDataset class + + Raises: + Exception: When the dataset constructor raises an OSError + + Returns: + Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantion, the + Iterable instantiation, or None + """ + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + + dataset = None + + # First, build on rank 0 + if rank == 0 and is_built_on_rank(): + try: + dataset = cls(*args) + except OSError as err: + log = ( + f"Failed to write dataset materials to the data cache directory. Please " + f"supply a directory to which you have write access via the path_to_cache " + f"attribute in BlendedMegatronDatasetConfig and retry. Refer to the " + f"preserved traceback above for more information." + ) + raise Exception(log) from err + + if synchronize_ranks: + torch.distributed.barrier() + + # After, build on other ranks + if rank != 0 and is_built_on_rank(): + dataset = cls(*args) + + return dataset + + return cls(*args) + + +def _get_size_per_split_per_dataset( + normalized_weights: List[float], target_size_per_split: List[int], margin: float = 0.0 +) -> List[List[int]]: + """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits + + Args: + normalized_weights (List[float]): e.g. [0.3, 0.7] + + target_size_per_split (List[int]): The number of samples to target for each BlendedDataset + split + + margin (float): The relative quantity of extra samples to build per per split per dataset, + as a percentage + + Returns: + List[List[int]]: The number of samples to request per MegatronDataset per split + """ + assert numpy.isclose(sum(normalized_weights), 1.0) + + # Use margin as buffer to ensure we satiate the request + sizes_per_dataset = [ + [ + int(math.ceil(math.ceil(target_size * weight) * (1 + margin / 100))) + for target_size in target_size_per_split + ] + for weight in normalized_weights + ] + + return sizes_per_dataset diff --git a/megatron/core/datasets/blended_megatron_dataset_config.py b/megatron/core/datasets/blended_megatron_dataset_config.py index 52bc31f62ef803923a48f3c9726f058ea77586bd..f79d0ebf83bf6a5723f0cab5dd08fe7751b3da21 100644 --- a/megatron/core/datasets/blended_megatron_dataset_config.py +++ b/megatron/core/datasets/blended_megatron_dataset_config.py @@ -1,177 +1,172 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import functools -import logging -import re -from dataclasses import dataclass, field -from typing import List, Optional, Tuple - -from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer -from megatron.core.datasets.utils import Split, log_single_rank, normalize - -logger = logging.getLogger(__name__) - - -@dataclass -class BlendedMegatronDatasetConfig: - """Configuration object for Megatron Core datasets""" - - random_seed: int - """The seed for all RNG during dataset creation.""" - - sequence_length: int - """The sequence length.""" - - blend: Optional[Tuple[List[str], Optional[List[float]]]] = None - """The blend, consisting of a list of dataset prefixes and optionally a list of dataset - weights. For example, [["dataset-path1", "dataset-path2"], [0.3, 0.7]]. When the weights are - None, they are inferred from the lengths of the contributing datasets. Not to be used with - 'blend_per_split'. Defaults to None. 
- """ - - blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] = None - """A set of blends, as defined above, one for each split distribution. Not to be used with - 'blend'. Defauls to None. - """ - - renormalize_blend_weights: bool = False - """Renormalize the blend weights to account for mid-level dataset oversampling done to ensure - fulfillmenet of the of the requested number of samples. Defaults to False for backward - comparability in the data sample order. - """ - - split: Optional[str] = None - """The split string, a comma separated weighting for the dataset splits when drawing samples - from a single distribution. Not to be used with 'blend_per_split'. Defaults to None. - """ - - split_matrix: Optional[List[Tuple[float, float]]] = field(init=False, default=None) - """The split matrix consisting of non-overlapping book-ends of each split in order. For more - information, refer to 'convert_split_vector_to_split_matrix'. Created automatically from - 'split'. Not to be passed in to the constructor. - """ - - num_dataset_builder_threads: int = 1 - """The number of threads to use for dataset building.""" - - path_to_cache: Optional[str] = None - """Where all re-useable dataset indices are to be cached.""" - - mmap_bin_files: bool = True - """Whether to mmap the .bin files or use file pointers.""" - - mock: bool = field(init=False, default=False) - """Whether to bypass real data loading and validation in favor of mock data generation. - Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the - constructor. - """ - - tokenizer: Optional[MegatronTokenizer] = None - """The MegatronTokenizer instance or None. Required for datasets which do online tokenization.""" - - def __post_init__(self) -> None: - """Do asserts and set fields post init""" - if self.blend_per_split is not None and any(self.blend_per_split): - assert self.blend is None, "blend and blend_per_split are incompatible" - assert self.split is None, "split and blend_per_split are incompatible" - assert len(self.blend_per_split) == len( - Split - ), f"blend_per_split must contain {len(Split)} blends" - for split in Split: - if self.blend_per_split[split.value] is None: - log_single_rank( - logger, logging.INFO, f"blend not provided for {split.name} split" - ) - else: - assert self.blend_per_split[split.value][1] is None or len( - self.blend_per_split[split.value][0] - ) == len( - self.blend_per_split[split.value][1] - ), "blend per split prefixes and weights must be equal in number" - else: - if self.blend is not None: - assert self.blend[1] is None or len(self.blend[0]) == len( - self.blend[1] - ), "blend prefixes and weights must be equal in number" - assert self.split is not None, "split must be provided when blend is not None" - else: - self.mock = True - log_single_rank( - logger, - logging.INFO, - f"Let mock = True, as both blend and blend_per_split are None", - ) - self.split = "1,1,1" - log_single_rank( - logger, - logging.INFO, - f"Let split = {self.split}, an arbitrarily even split, as mock is True", - ) - split_vector = parse_and_normalize_split(self.split) - self.split_matrix = convert_split_vector_to_split_matrix(split_vector) - log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}") - - -def parse_and_normalize_split(split: str) -> List[float]: - """Parse the dataset split ratios from a string - - Args: - split (str): The train valid test split string e.g. "99,1,0" - - Returns: - List[float]: The trian valid test split ratios e.g. 
[0.99, 0.01, 0.0] - """ - split = list(map(float, re.findall(r"[.0-9]+", split))) - split = split + [0.0 for _ in range(len(Split) - len(split))] - - assert len(split) == len(Split) - assert all(map(lambda _: _ >= 0.0, split)) - - split = normalize(split) - - return split - - -def convert_split_vector_to_split_matrix( - vector_a: List[float], vector_b: Optional[List[float]] = None -) -> List[Optional[Tuple[float, float]]]: - """Build the split matrix from one or optionally two contributing split vectors. - - Ex. a standard conversion: - - [0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None] - - Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro - preprocessing used a [0.98, 0.02, 0.0] split: - - [0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None] - - Args: - vector_a (List[float]): The primary split vector - - vector_b (Optional[List[float]]): An optional secondary split vector which constrains the primary split vector. Defaults to None. - - Returns: - List[Tuple[float, float]]: The split matrix consisting of book-ends of each split in order - """ - if vector_b is None: - vector_b = vector_a - - # [.900, .090, .010] -> [0.00, .900, .990, 100] - expansion_a = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_a]) - expansion_b = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_b]) - - # [0.00, .900, .990, 100.0] -> [(0.00, .900), (.900, .990), (.990, 100)] - bookends_a = list(zip(expansion_a[:-1], expansion_a[1:])) - bookends_b = list(zip(expansion_b[:-1], expansion_b[1:])) - - # gather per-split overlap or None - matrix = [] - for bookend_a, bookend_b in zip(bookends_a, bookends_b): - if min(bookend_a[1], bookend_b[1]) <= max(bookend_a[0], bookend_b[0]): - overlap = None - else: - overlap = (max(bookend_a[0], bookend_b[0]), min(bookend_a[1], bookend_b[1])) - matrix.append(overlap) - - return matrix +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import functools +import logging +import re +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer +from megatron.core.datasets.utils import Split, log_single_rank, normalize + +logger = logging.getLogger(__name__) + + +@dataclass +class BlendedMegatronDatasetConfig: + """Configuration object for Megatron Core datasets""" + + random_seed: int + """The seed for all RNG during dataset creation.""" + + sequence_length: int + """The sequence length.""" + + blend: Optional[Tuple[List[str], Optional[List[float]]]] = None + """The blend, consisting of a list of dataset prefixes and optionally a list of dataset + weights. For example, [["dataset-path1", "dataset-path2"], [0.3, 0.7]]. When the weights are + None, they are inferred from the lengths of the contributing datasets. Not to be used with + 'blend_per_split'. Defaults to None. + """ + + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] = None + """A set of blends, as defined above, one for each split distribution. Not to be used with + 'blend'. Defauls to None. + """ + + split: Optional[str] = None + """The split string, a comma separated weighting for the dataset splits when drawing samples + from a single distribution. Not to be used with 'blend_per_split'. Defaults to None. 
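    For example, split="949,50,1" draws 94.9% of the samples for training, 5.0%
    for validation, and 0.1% for testing.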
+ """ + + split_matrix: Optional[List[Tuple[float, float]]] = field(init=False, default=None) + """The split matrix consisting of non-overlapping book-ends of each split in order. For more + information, refer to 'convert_split_vector_to_split_matrix'. Created automatically from + 'split'. Not to be passed in to the constructor. + """ + + num_dataset_builder_threads: int = 1 + """The number of threads to use for dataset building.""" + + path_to_cache: Optional[str] = None + """Where all re-useable dataset indices are to be cached.""" + + mmap_bin_files: bool = True + """Whether to mmap the .bin files or use file pointers.""" + + mock: bool = field(init=False, default=False) + """Whether to bypass real data loading and validation in favor of mock data generation. + Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the + constructor. + """ + + tokenizer: Optional[MegatronTokenizer] = None + """The MegatronTokenizer instance. Required for datasets that do online tokenization.""" + + def __post_init__(self) -> None: + """Do asserts and set fields post init""" + if self.blend_per_split is not None and any(self.blend_per_split): + assert self.blend is None, "blend and blend_per_split are incompatible" + assert self.split is None, "split and blend_per_split are incompatible" + assert len(self.blend_per_split) == len( + Split + ), f"blend_per_split must contain {len(Split)} blends" + for split in Split: + if self.blend_per_split[split.value] is None: + log_single_rank( + logger, logging.INFO, f"blend not provided for {split.name} split" + ) + else: + assert self.blend_per_split[split.value][1] is None or len( + self.blend_per_split[split.value][0] + ) == len( + self.blend_per_split[split.value][1] + ), "blend per split prefixes and weights must be equal in number" + else: + if self.blend is not None: + assert self.blend[1] is None or len(self.blend[0]) == len( + self.blend[1] + ), "blend prefixes and weights must be equal in number" + assert self.split is not None, "split must be provided when blend is not None" + else: + self.mock = True + log_single_rank( + logger, + logging.INFO, + f"Let mock = True, as both blend and blend_per_split are None", + ) + self.split = "1,1,1" + log_single_rank( + logger, + logging.INFO, + f"Let split = {self.split}, an arbitrarily even split, as mock is True", + ) + split_vector = parse_and_normalize_split(self.split) + self.split_matrix = convert_split_vector_to_split_matrix(split_vector) + log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}") + + +def parse_and_normalize_split(split: str) -> List[float]: + """Parse the dataset split ratios from a string + + Args: + split (str): The train valid test split string e.g. "99,1,0" + + Returns: + List[float]: The trian valid test split ratios e.g. [0.99, 0.01, 0.0] + """ + split = list(map(float, re.findall(r"[.0-9]+", split))) + split = split + [0.0 for _ in range(len(Split) - len(split))] + + assert len(split) == len(Split) + assert all(map(lambda _: _ >= 0.0, split)) + + split = normalize(split) + + return split + + +def convert_split_vector_to_split_matrix( + vector_a: List[float], vector_b: Optional[List[float]] = None +) -> List[Optional[Tuple[float, float]]]: + """Build the split matrix from one or optionally two contributing split vectors. + + Ex. a standard conversion: + + [0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None] + + Ex. 
a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro + preprocessing used a [0.98, 0.02, 0.0] split: + + [0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None] + + Args: + vector_a (List[float]): The primary split vector + + vector_b (Optional[List[float]]): An optional secondary split vector which constrains the + primary split vector. Defaults to None. + + Returns: + List[Tuple[float, float]]: The split matrix consisting of book-ends of each split in order + """ + if vector_b is None: + vector_b = vector_a + + # [.900, .090, .010] -> [0.00, .900, .990, 100] + expansion_a = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_a]) + expansion_b = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_b]) + + # [0.00, .900, .990, 100.0] -> [(0.00, .900), (.900, .990), (.990, 100)] + bookends_a = list(zip(expansion_a[:-1], expansion_a[1:])) + bookends_b = list(zip(expansion_b[:-1], expansion_b[1:])) + + # gather per-split overlap or None + matrix = [] + for bookend_a, bookend_b in zip(bookends_a, bookends_b): + if min(bookend_a[1], bookend_b[1]) <= max(bookend_a[0], bookend_b[0]): + overlap = None + else: + overlap = (max(bookend_a[0], bookend_b[0]), min(bookend_a[1], bookend_b[1])) + matrix.append(overlap) + + return matrix diff --git a/megatron/core/dist_checkpointing/__init__.py b/megatron/core/dist_checkpointing/__init__.py index eb7ad78a42b6b6d2d5f43867c6c3ecda27c0a831..c9d059b183c4a9f4c3e68e86fd807ee5cfcb600d 100644 --- a/megatron/core/dist_checkpointing/__init__.py +++ b/megatron/core/dist_checkpointing/__init__.py @@ -1,12 +1,12 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -from .core import check_is_distributed_checkpoint -from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor -from .serialization import ( - load, - load_common_state_dict, - load_plain_tensors, - load_tensors_metadata, - remove_sharded_tensors, - save, -) +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +from .core import check_is_distributed_checkpoint +from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor +from .serialization import ( + load, + load_common_state_dict, + load_plain_tensors, + load_tensors_metadata, + remove_sharded_tensors, + save, +) diff --git a/megatron/core/dist_checkpointing/exchange_utils.py b/megatron/core/dist_checkpointing/exchange_utils.py index 2106fe574c2dfa74679f26ef92323968bee87c6a..8a9b52cef8264bb9e51e0a96f4e531dedd4dd2de 100644 --- a/megatron/core/dist_checkpointing/exchange_utils.py +++ b/megatron/core/dist_checkpointing/exchange_utils.py @@ -1,519 +1,544 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
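# A minimal sketch of how the public API re-exported by the dist_checkpointing
# __init__ above is typically used; the checkpoint directory and the model's
# sharded_state_dict() helper are assumptions, not shown in this diff.
#
#     from megatron.core import dist_checkpointing
#
#     sharded_state_dict = model.sharded_state_dict()  # hypothetical helper on the model
#     dist_checkpointing.save(sharded_state_dict, "checkpoints/iter_0001000")
#     if dist_checkpointing.check_is_distributed_checkpoint("checkpoints/iter_0001000"):
#         loaded = dist_checkpointing.load(model.sharded_state_dict(), "checkpoints/iter_0001000")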
- -"""Utilities for exchanging data between ranks.""" - -import logging -from collections import defaultdict -from functools import reduce -from itertools import zip_longest -from time import time -from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast - -import numpy as np -import torch - -from .core import CheckpointingException -from .dict_utils import nested_values -from .mapping import ShardedStateDict, ShardedTensor, is_main_replica -from .utils import _sharded_tensor_shard_id, _ShardId - -# TODO: remove TE references once the TE bug is fixed -# Check if Transformer Engine has Float8Tensor class -HAVE_TE_FLOAT8TENSOR = False -try: - from transformer_engine.pytorch.float8_tensor import Float8Tensor - - HAVE_TE_FLOAT8TENSOR = True -except (ImportError, ModuleNotFoundError): - # Float8Tensor not found - pass - - -def is_float8tensor(tensor: torch.Tensor) -> bool: - """Check if a tensor is a Transformer Engine Float8Tensor""" - return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) - - -logger = logging.getLogger(__name__) - - -class ShardDistribution(NamedTuple): - """Represents a distribution of ShardedTensors. - - Given distribution is valid only for a specific parallelization group, - which is implicit here (not referenced by this class). - - Args: - main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold - the main replica for a given shard - shards_in_this_group (Set[_ShardId]): which shards have a main replica - in this parallelization group - shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor - identifier to the original ShardedTensor - all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks - need a given shard in a given parallelization group - - """ - - main_rank_for_shard: Dict[_ShardId, int] - shards_in_this_group: Set[_ShardId] - shard_to_metadata: Dict[_ShardId, ShardedTensor] - all_ranks_for_shard: Dict[_ShardId, List[int]] - - -def _shard_size(sh_ten: ShardedTensor): - """Returns size in bytes of a given sharded tensor.""" - if sh_ten.flattened_range is None: - numel = np.product(sh_ten.local_shape) - else: - numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start - return numel * torch._utils._element_size(sh_ten.dtype) - - -def _get_empty_tensor_for_exchange( - shard_id: _ShardId, - needed_shards: Dict[_ShardId, ShardedTensor], - unneeded_shards: Dict[_ShardId, ShardedTensor], - loaded_tensors: Dict[_ShardId, torch.Tensor], -) -> Tuple[torch.Tensor, Optional[torch.device]]: - """Determines the empty tensor to use for exchange. - - If shard_id is needed by this rank, it will be in the `unloaded_shards`. 
- Otherwise, the metadata for this tensor can be found in `shard_to_metadata` - - Args: - shard_id (_ShardId): shard_id that will be exchanged - needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids - to metadata for shards needed by this rank - unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids - to metadata for shards that can be discarded after exchange - loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors - are placed in - - Returns: - Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged, - and the device of the original state dict tensor (if there was any) - """ - local_unloaded_sh_ten = needed_shards.get(shard_id) - if local_unloaded_sh_ten is None: - orig_device = None # this tensor will be discarded anyway - sh_ten = unneeded_shards[shard_id] - if sh_ten.data is None: - sh_ten.init_data('cuda') - tensor = sh_ten.data - sh_ten.data = None # won't be used. free memory - else: - tensor = sh_ten.data - if tensor.device.type == 'cpu': - tensor = torch.empty_like(tensor, device='cuda') - else: - local_unloaded_sh_ten.init_data('cuda') - orig_device = local_unloaded_sh_ten.data.device - tensor = local_unloaded_sh_ten.data - if tensor.device.type == 'cpu': - tensor = torch.empty_like(tensor, device='cuda') - loaded_tensors[shard_id] = tensor - return tensor, orig_device - - -T = TypeVar('T') - - -def distribute_shards_to_ranks( - shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int -) -> Dict[T, int]: - """Computes uniform distribution of workload across ranks, based on sizes. - - Currently, the assignment is greedy, based on: - 1. Firstly, the coverage of each shard - (how many ranks the shard is available on; lower coverage is assigned first) - 2. Secondly, the size of each shard (larger size is assigned first) - 3. Finally, shard id for differentiation. - - Third step is added because we rely on the fact that - the assignment is deterministic on all ranks. - - Args: - shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards - shard_to_size (Dict[T, int]): sizes of each shard - num_ranks (int): number of ranks in the parallelization group - - Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work - to achieve maximal uniformity) - """ - shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()} - shard_to_saving_rank = {} - rank_sizes = [(0, rank) for rank in range(num_ranks)] - - # start from tensors of lowest coverage, then go by tensor size from largest (hence minus size) - for shard_id, shard_ranks in sorted( - shard_to_ranks.items(), - key=lambda sh_id_ranks: ( - len(sh_id_ranks[1]), - -shard_to_size[sh_id_ranks[0]], - sh_id_ranks[0], - ), - ): - # assign greedily to the least occupied rank - size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks) - - shard_to_saving_rank[shard_id] = rank - rank_sizes[rank] = (size + shard_to_size[shard_id], rank) - - logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}') - - return shard_to_saving_rank - - -def determine_main_replica_uniform_distribution( - sharded_state_dict: ShardedStateDict, - parallelization_group: torch.distributed.ProcessGroup, - ignore_groups: bool = False, -) -> Optional[ShardDistribution]: - """Computes the save distribution. - - Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution` - which applies the computed save distribution. 
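    As an illustration of the assignment computed by distribute_shards_to_ranks
    above (plain strings stand in for _ShardId tuples): shard 'c' is visible only
    to rank 1, so it is assigned first; the remaining shards then go to whichever
    rank currently holds fewer bytes.

        shard_to_ranks = {'a': [0, 1], 'b': [0, 1], 'c': [1]}
        shard_to_size = {'a': 100, 'b': 60, 'c': 40}
        distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
        # == {'c': 1, 'a': 0, 'b': 1}, i.e. 100 bytes of work per rank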
- - We rely on the fact that the assignment algorithm is deterministic on all ranks, - so there is no extra communication needed after metadata exchange. - - Args: - sharded_state_dict (ShardedStateDict): state dict to compute the distribution of - parallelization_group (ProcessGroup): distribution will be computed - within this process group - ignore_groups (bool, optional): whether the distribution defines groups. - This option is primarily used during loading, as it ensures that all replicas, - including non-main ones, are loaded by this parallelization group - Defaults to False. - - Returns (ShardDistribution, optional): distribution that can be used to apply the - parallelization. Returns None if the process_group is trivial (1 rank) - - """ - group_size = torch.distributed.get_world_size(group=parallelization_group) - if group_size <= 1: - return - local_shards = list( - sh_base - for sh_base in nested_values(sharded_state_dict) - if isinstance(sh_base, ShardedTensor) - ) - local_shards_no_data = [ten.without_data() for ten in local_shards] - - all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group) - torch.distributed.all_gather_object( - all_shards, local_shards_no_data, group=parallelization_group - ) - - shard_to_ranks = defaultdict(list) - shard_to_size = {} - shard_to_metadata = {} - shards_in_this_parallelization_group: Set[_ShardId] = set() - for rank, rank_shards in enumerate(all_shards): - for sh_ten in rank_shards: - shard_id = _sharded_tensor_shard_id(sh_ten) - shard_to_ranks[shard_id].append(rank) - if shard_id not in shard_to_size: - shard_to_size[shard_id] = _shard_size(sh_ten) - shard_to_metadata[shard_id] = sh_ten - if is_main_replica(sh_ten.replica_id) or ignore_groups: - shards_in_this_parallelization_group.add(shard_id) - - shard_to_ranks = { - k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group - } - - shard_to_saving_rank = distribute_shards_to_ranks( - shard_to_ranks, shard_to_size, len(all_shards) - ) - - return ShardDistribution( - shard_to_saving_rank, - shards_in_this_parallelization_group, - shard_to_metadata, - shard_to_ranks, - ) - - -@torch.no_grad() -def exchange_loaded_tensors_gather_rounds( - loaded_tensors: Dict[_ShardId, torch.Tensor], - unloaded_shards: Dict[_ShardId, ShardedTensor], - shard_distribution: ShardDistribution = None, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, -) -> Dict[_ShardId, torch.Tensor]: - """Exchange the tensors loaded by different ranks with several all_gather calls. - - Groups tensors by dtype, divide tensors that will be exchanged into rounds - and execute all_gather for tensors from each round. - - Note: the loading is distributed across ranks based on total loaded size - in bytes, so there is no guarantee that number of rounds needed for each - rank will be similar, which might result in a lot of almost empty - all_gathers. The solution would be to group all tensors into a one - bytes tensor and do a single all_gather (with similarly sized messages). - - Args: - loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to tensors already loaded by this rank. - unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to ShardedTensors that aren't loaded yet. - shard_distribution (ShardDistribution): distribution of all shards - parallelization_group (ProcessGroup, optional): process group used for load - distribution. 
Tensors will be exchanged within this group - - Returns: - Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors - needed by this rank to load a given state dict. Includes - previously loaded tensors (from `loaded_tensors` input) - """ - main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution - local_rank = torch.distributed.get_rank(group=parallelization_group) - - all_loaded_tensors = dict(loaded_tensors) - - # Group by dtype so that we all_gather tensors of the same dtype - for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str): - - start = time() - # shards_by_rank maps rank to tensors loaded by this rank - shards_by_rank: List[List[torch.Tensor]] = [ - [] for _ in range(torch.distributed.get_world_size(group=parallelization_group)) - ] - for shard_id, rank in main_rank_for_shard.items(): - if len(all_ranks_for_shard[shard_id]) == 1: - assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( - f'When there is only 1 ranks that needs a given shard,' - f' it should be the loading rank.' - f' Got: needs [{all_ranks_for_shard[shard_id][0]}]' - f' vs loads [{main_rank_for_shard[shard_id]}]' - ) - # Skipping the exchange since only the loading rank needs this tensor - # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` - # case, e.g. P2P exchange. Currently handling this case saves most of the - # work though. - continue - if shard_to_metadata[shard_id].dtype == dtype: - shards_by_rank[rank].append(shard_id) - - # Transpose `shards_by_rank` to form exchange rounds - shards_by_round = zip_longest(*shards_by_rank, fillvalue=None) - for round_idx, round_shard_ids in enumerate(shards_by_round): - round_tensors = [] - orig_devices = {} - for rank, shard_id in enumerate(round_shard_ids): - if shard_id is None: - # if no more useful data, the given rank will exchange empty tensor - local_ten = torch.empty(0, dtype=dtype, device='cuda') - orig_device = None - else: - assert isinstance(shard_id, tuple), type(shard_id) - if rank == local_rank: - assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) - orig_device = all_loaded_tensors[shard_id] - all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda() - local_ten = all_loaded_tensors[shard_id] - else: - local_ten, orig_device = _get_empty_tensor_for_exchange( - shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors - ) - # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 - # It's ok to keep the nominal dtype after exchange, because TE will handle - # this during state dict load. 
- # TODO: remove it once the bug is fixed - if is_float8tensor(local_ten): - local_ten = local_ten.from_float8() - all_loaded_tensors[shard_id] = local_ten - - round_tensors.append(local_ten) - if orig_device is not None: - orig_devices[shard_id] = orig_device - - torch.distributed.all_gather( - list(round_tensors), - round_tensors[local_rank], - group=parallelization_group, - async_op=False, - ) - - # Move tensors back to CPU if originally was on CPU - for shard_id, orig_device in orig_devices.items(): - all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device) - - del round_tensors # remove tensor references - - end = time() - if torch.distributed.get_rank() == 0: - logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s') - - return all_loaded_tensors - - -def exchange_loaded_tensors_gather_object( - loaded_tensors: Dict[_ShardId, torch.Tensor], - unloaded_shards: Dict[_ShardId, ShardedTensor], - shard_distribution: ShardDistribution, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, -) -> Dict[_ShardId, torch.Tensor]: - """Exchange the tensors loaded by different ranks with a simple all_gather_object call. - - This version can be used for debugging purposes do to its simplistic - implementation. Shouldn't be used if performance is important. - - Args: - loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to tensors already loaded by this rank. - unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to ShardedTensors that aren't loaded yet. - shard_distribution (ShardDistribution): distribution of all shards - parallelization_group (ProcessGroup, optional): process group used for load - distribution. Tensors will be exchanged within this group - - Returns: - Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors - needed by this rank to load a given state dict. Includes - previously loaded tensors (from `loaded_tensors` input) - - """ - all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group) - torch.distributed.all_gather_object( - all_loaded_tensors_list, loaded_tensors, group=parallelization_group - ) - all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list) - all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list) - - # Error checks - if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)): - err_msg = 'Duplicate shard ids loaded by different ranks' - if torch.distributed.get_rank() == 0: - logger.error( - f'{err_msg}. Shards ids by rank:' - f' {[lt.keys() for lt in all_loaded_tensors_list]}' - ) - raise CheckpointingException(err_msg) - - return all_loaded_tensors - - -@torch.no_grad() -def exchange_loaded_tensors_broadcast( - loaded_tensors: Dict[_ShardId, torch.Tensor], - unloaded_shards: Dict[_ShardId, ShardedTensor], - shard_distribution: ShardDistribution, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, -) -> Dict[_ShardId, torch.Tensor]: - """Exchange the tensors loaded by different ranks by a series of broadcasts. - - For each rank for each loaded tensor do a broadcast to the whole group. - A reasonable tradeoff in terms of performance and simplicity. - - Args: - loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to tensors already loaded by this rank. 
- unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor - shard ids to ShardedTensors that aren't loaded yet. - shard_distribution (ShardDistribution): distribution of all shards - parallelization_group (ProcessGroup, optional): process group used for load - distribution. Tensors will be exchanged within this group - - Returns: - Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors - needed by this rank to load a given state dict. Includes - previously loaded tensors (from `loaded_tensors` input) - """ - main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution - local_rank = torch.distributed.get_rank(group=parallelization_group) - - all_loaded_tensors = dict(loaded_tensors) - - start = time() - - for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()): - if len(all_ranks_for_shard[shard_id]) == 1: - assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( - f'When there is only 1 ranks that needs a given shard,' - f' it should be the loading rank.' - f'Got: needs [{all_ranks_for_shard[shard_id][0]}]' - f' vs loads [{main_rank_for_shard[shard_id]}]' - ) - # Skipping the exchange since only the loading rank needs this tensor - # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case, - # e.g. P2P exchange. Currently handling this case saves most of the work though. - continue - if rank == local_rank: - assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) - orig_device = all_loaded_tensors[shard_id].device - local_ten = all_loaded_tensors[shard_id].cuda() - else: - local_ten, orig_device = _get_empty_tensor_for_exchange( - shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors - ) - - # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 - # It's ok to keep the nominal dtype after exchange, because TE will handle - # this during state dict load. - # TODO: remove it once the bug is fixed - if is_float8tensor(local_ten): - local_ten = local_ten.from_float8() - all_loaded_tensors[shard_id] = local_ten - - global_src_rank = ( - rank - if parallelization_group == None - else torch.distributed.get_global_rank(parallelization_group, rank) - ) - # We can do async_op=True only if there is no CPU-copy follow-up - torch.distributed.broadcast( - local_ten, - src=global_src_rank, - group=parallelization_group, - async_op=orig_device is None, - ) - # Move tensor back to CPU if originally was on CPU - if orig_device is not None: - all_loaded_tensors[shard_id] = local_ten.to(orig_device) - del local_ten - - end = time() - if torch.distributed.get_rank() == 0: - logger.debug(f'exchange broadcast schedule took {end - start}s') - - return all_loaded_tensors - - -def exchange_by_distribution( - loaded_tensors: Dict[_ShardId, torch.Tensor], - unloaded_shards: Dict[_ShardId, ShardedTensor], - shard_distribution: ShardDistribution = None, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, - exchange_algo='broadcast', -) -> Dict[_ShardId, torch.Tensor]: - """Exchange tensors loaded by different ranks using the specified exchange_algo. - - Args: - loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor - shard ids to tensors already loaded by this rank. - unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor - shard ids to ShardedTensors that aren't loaded yet. 
- shard_distribution (ShardDistribution): distribution of all shards - parallelization_group (ProcessGroup, optional): process group used for load - distribution. Tensors will be exchanged within this group - exchange_algo (str): The algorithm used for performing exchanges. - Defaults to 'broadcast'. - - Returns: - Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors - needed by this rank to load a given state dict. Includes - previously loaded tensors (from `loaded_tensors` input) - """ - - if exchange_algo == 'gather_object': - exchange_fn = exchange_loaded_tensors_gather_object - elif exchange_algo == 'gather_rounds': - exchange_fn = exchange_loaded_tensors_gather_rounds - elif exchange_algo == 'broadcast': - exchange_fn = exchange_loaded_tensors_broadcast - else: - raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}') - return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group) +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for exchanging data between ranks.""" + +import logging +from collections import defaultdict +from functools import reduce +from itertools import zip_longest +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast + +import numpy as np +import torch + +from .core import CheckpointingException +from .dict_utils import nested_values +from .mapping import ShardedStateDict, ShardedTensor, is_main_replica +from .utils import _sharded_tensor_shard_id, _ShardId, debug_time + +# TODO: remove TE references once the TE bug is fixed +# Check if Transformer Engine has Float8Tensor class +HAVE_TE_FLOAT8TENSOR = False +try: + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE_FLOAT8TENSOR = True +except (ImportError, ModuleNotFoundError): + # Float8Tensor not found + pass + + +def is_float8tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Transformer Engine Float8Tensor""" + return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) + + +logger = logging.getLogger(__name__) + + +class ShardDistribution(NamedTuple): + """Represents a distribution of ShardedTensors. + + Given distribution is valid only for a specific parallelization group, + which is implicit here (not referenced by this class). 
+ + Args: + main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold + the main replica for a given shard + shards_in_this_group (Set[_ShardId]): which shards have a main replica + in this parallelization group + shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor + identifier to the original ShardedTensor + all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks + need a given shard in a given parallelization group + """ + + main_rank_for_shard: Dict[_ShardId, int] + shards_in_this_group: Set[_ShardId] + shard_to_metadata: Dict[_ShardId, ShardedTensor] + all_ranks_for_shard: Dict[_ShardId, List[int]] + + +def _shard_size(sh_ten: ShardedTensor): + """Returns size in bytes of a given sharded tensor.""" + if sh_ten.flattened_range is None: + numel = np.product(sh_ten.local_shape) + else: + numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start + return numel * torch._utils._element_size(sh_ten.dtype) + + +def _get_empty_tensor_for_exchange( + shard_id: _ShardId, + needed_shards: Dict[_ShardId, ShardedTensor], + unneeded_shards: Dict[_ShardId, ShardedTensor], + loaded_tensors: Dict[_ShardId, torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.device]]: + """Determines the empty tensor to use for exchange. + + If shard_id is needed by this rank, it will be in the `unloaded_shards`. + Otherwise, the metadata for this tensor can be found in `shard_to_metadata` + + Args: + shard_id (_ShardId): shard_id that will be exchanged + needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids + to metadata for shards needed by this rank + unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids + to metadata for shards that can be discarded after exchange + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors + are placed in + + Returns: + Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged, + and the device of the original state dict tensor (if there was any) + """ + local_unloaded_sh_ten = needed_shards.get(shard_id) + if local_unloaded_sh_ten is None: + orig_device = None # this tensor will be discarded anyway + sh_ten = unneeded_shards[shard_id] + if sh_ten.data is None: + sh_ten.init_data('cuda') + tensor = sh_ten.data + sh_ten.data = None # won't be used. free memory + else: + tensor = sh_ten.data + if tensor.device.type == 'cpu': + tensor = torch.empty_like(tensor, device='cuda') + else: + local_unloaded_sh_ten.init_data('cuda') + orig_device = local_unloaded_sh_ten.data.device + tensor = local_unloaded_sh_ten.data + if tensor.device.type == 'cpu': + tensor = torch.empty_like(tensor, device='cuda') + loaded_tensors[shard_id] = tensor + return tensor, orig_device + + +T = TypeVar('T') + + +def distribute_shards_to_ranks( + shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int +) -> Dict[T, int]: + """Computes uniform distribution of workload across ranks, based on sizes. + + Currently, the assignment is greedy, based on: + 1. Firstly, the coverage of each shard + (how many ranks the shard is available on; lower coverage is assigned first) + 2. Secondly, the size of each shard (larger size is assigned first) + 3. Finally, shard id for differentiation. + + Third step is added because we rely on the fact that + the assignment is deterministic on all ranks. 
+ + Args: + shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards + shard_to_size (Dict[T, int]): sizes of each shard + num_ranks (int): number of ranks in the parallelization group + + Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work + to achieve maximal uniformity) + """ + shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()} + shard_to_saving_rank = {} + rank_sizes = [(0, rank) for rank in range(num_ranks)] + + # start from tensors of lowest coverage, then go by tensor size from largest (hence minus size) + for shard_id, shard_ranks in sorted( + shard_to_ranks.items(), + key=lambda sh_id_ranks: ( + len(sh_id_ranks[1]), + -shard_to_size[sh_id_ranks[0]], + sh_id_ranks[0], + ), + ): + # assign greedily to the least occupied rank + size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks) + + shard_to_saving_rank[shard_id] = rank + rank_sizes[rank] = (size + shard_to_size[shard_id], rank) + + logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}') + + return shard_to_saving_rank + + +def determine_main_replica_uniform_distribution( + sharded_state_dict: ShardedStateDict, + parallelization_group: torch.distributed.ProcessGroup, + ignore_groups: bool = False, +) -> Optional[ShardDistribution]: + """Computes the save distribution. + + Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution` + which applies the computed save distribution. + + We rely on the fact that the assignment algorithm is deterministic on all ranks, + so there is no extra communication needed after metadata exchange. + + Args: + sharded_state_dict (ShardedStateDict): state dict to compute the distribution of + parallelization_group (ProcessGroup): distribution will be computed + within this process group + ignore_groups (bool, optional): whether the distribution defines groups. + This option is primarily used during loading, as it ensures that all replicas, + including non-main ones, are loaded by this parallelization group + Defaults to False. + + Returns (ShardDistribution, optional): distribution that can be used to apply the + parallelization. 
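For illustration, a minimal sketch of the greedy assignment described in the docstring above, using made-up shard ids and byte sizes; the import path is an assumption about where these exchange utilities live:

from megatron.core.dist_checkpointing.exchange_utils import (  # assumed module path
    distribute_shards_to_ranks,
)

# Which ranks have access to each shard, and shard sizes in bytes (hypothetical values).
shard_to_ranks = {'a': [0, 1], 'b': [0, 1], 'c': [1]}
shard_to_size = {'a': 100, 'b': 40, 'c': 80}

assignment = distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
# 'c' is assigned first (coverage 1) -> rank 1; then 'a' (largest) -> the empty rank 0;
# then 'b' -> rank 1, the less loaded eligible rank at that point (80 < 100).
assert assignment == {'c': 1, 'a': 0, 'b': 1}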
Returns None if the process_group is trivial (1 rank) + + """ + group_size = torch.distributed.get_world_size(group=parallelization_group) + if group_size <= 1: + return + local_shards = list( + sh_base + for sh_base in nested_values(sharded_state_dict) + if isinstance(sh_base, ShardedTensor) + ) + local_shards_no_data = [ten.without_data() for ten in local_shards] + + all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group) + torch.distributed.all_gather_object( + all_shards, local_shards_no_data, group=parallelization_group + ) + + shard_to_ranks = defaultdict(list) + shard_to_size = {} + shard_to_metadata = {} + shards_in_this_parallelization_group: Set[_ShardId] = set() + for rank, rank_shards in enumerate(all_shards): + for sh_ten in rank_shards: + shard_id = _sharded_tensor_shard_id(sh_ten) + shard_to_ranks[shard_id].append(rank) + if shard_id not in shard_to_size: + shard_to_size[shard_id] = _shard_size(sh_ten) + shard_to_metadata[shard_id] = sh_ten + if is_main_replica(sh_ten.replica_id) or ignore_groups: + shards_in_this_parallelization_group.add(shard_id) + + shard_to_ranks = { + k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group + } + + shard_to_saving_rank = distribute_shards_to_ranks( + shard_to_ranks, shard_to_size, len(all_shards) + ) + + return ShardDistribution( + shard_to_saving_rank, + shards_in_this_parallelization_group, + shard_to_metadata, + shard_to_ranks, + ) + + +@torch.no_grad() +@debug_time(f"exchange_loaded_tensors_gather_rounds", logger) +def exchange_loaded_tensors_gather_rounds( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution = None, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks with several all_gather calls. + + Groups tensors by dtype, divide tensors that will be exchanged into rounds + and execute all_gather for tensors from each round. + + Note: the loading is distributed across ranks based on total loaded size + in bytes, so there is no guarantee that number of rounds needed for each + rank will be similar, which might result in a lot of almost empty + all_gathers. The solution would be to group all tensors into a one + bytes tensor and do a single all_gather (with similarly sized messages). + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. 
Includes + previously loaded tensors (from `loaded_tensors` input) + """ + main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution + local_rank = torch.distributed.get_rank(group=parallelization_group) + + all_loaded_tensors = dict(loaded_tensors) + + # Group by dtype so that we all_gather tensors of the same dtype + for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str): + + with debug_time(f"dtype_{dtype}"): + # shards_by_rank maps rank to tensors loaded by this rank + shards_by_rank: List[List[torch.Tensor]] = [ + [] for _ in range(torch.distributed.get_world_size(group=parallelization_group)) + ] + for shard_id, rank in main_rank_for_shard.items(): + if len(all_ranks_for_shard[shard_id]) == 1: + assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( + f'When there is only 1 ranks that needs a given shard,' + f' it should be the loading rank.' + f' Got: needs [{all_ranks_for_shard[shard_id][0]}]' + f' vs loads [{main_rank_for_shard[shard_id]}]' + ) + # Skipping the exchange since only the loading rank needs this tensor + # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` + # case, e.g. P2P exchange. Currently handling this case saves most of the + # work though. + continue + if shard_to_metadata[shard_id].dtype == dtype: + shards_by_rank[rank].append(shard_id) + + # Transpose `shards_by_rank` to form exchange rounds + shards_by_round = zip_longest(*shards_by_rank, fillvalue=None) + for round_idx, round_shard_ids in enumerate(shards_by_round): + round_tensors = [] + orig_devices = {} + for rank, shard_id in enumerate(round_shard_ids): + if shard_id is None: + # if no more useful data, the given rank will exchange empty tensor + local_ten = torch.empty(0, dtype=dtype, device='cuda') + orig_device = None + else: + assert isinstance(shard_id, tuple), type(shard_id) + if rank == local_rank: + assert shard_id in all_loaded_tensors, ( + shard_id, + all_loaded_tensors.keys(), + ) + orig_device = all_loaded_tensors[shard_id] + all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda() + local_ten = all_loaded_tensors[shard_id] + else: + local_ten, orig_device = _get_empty_tensor_for_exchange( + shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors + ) + # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 + # It's ok to keep the nominal dtype after exchange, because TE will handle + # this during state dict load. 
+ # TODO: remove it once the bug is fixed + if is_float8tensor(local_ten): + local_ten = local_ten.from_float8() + all_loaded_tensors[shard_id] = local_ten + + round_tensors.append(local_ten) + if orig_device is not None: + orig_devices[shard_id] = orig_device + + torch.distributed.all_gather( + list(round_tensors), + round_tensors[local_rank], + group=parallelization_group, + async_op=False, + ) + + # Move tensors back to CPU if originally was on CPU + for shard_id, orig_device in orig_devices.items(): + all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device) + + del round_tensors # remove tensor references + + return all_loaded_tensors + + +def exchange_loaded_tensors_gather_object( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks with a simple all_gather_object call. + + This version can be used for debugging purposes do to its simplistic + implementation. Shouldn't be used if performance is important. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + + """ + all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group) + torch.distributed.all_gather_object( + all_loaded_tensors_list, loaded_tensors, group=parallelization_group + ) + all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list) + all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list) + + # Error checks + if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)): + err_msg = 'Duplicate shard ids loaded by different ranks' + if torch.distributed.get_rank() == 0: + logger.error( + f'{err_msg}. Shards ids by rank:' + f' {[lt.keys() for lt in all_loaded_tensors_list]}' + ) + raise CheckpointingException(err_msg) + + return all_loaded_tensors + + +def exchange_loaded_objects_gather_object( + loaded_objects: Dict[_ShardId, Any] +) -> Dict[_ShardId, Any]: + """Exchange the objects loaded by different ranks with a simple all_gather_object call. + + Args: + loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects + already loaded by this rank. + + Returns: + Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to + load a given state dict. 
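As a standalone illustration of how the per-rank shard lists in exchange_loaded_tensors_gather_rounds are transposed into exchange rounds (the shard ids below are placeholders):

from itertools import zip_longest

# Hypothetical shard ids loaded by each of three ranks; rank 2 has nothing to send.
shards_by_rank = [['s0', 's1'], ['s2'], []]
rounds = list(zip_longest(*shards_by_rank, fillvalue=None))
# rounds == [('s0', 's2', None), ('s1', None, None)]
# In each round every rank contributes exactly one tensor to the all_gather;
# a None entry means that rank sends an empty placeholder tensor.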
+ """ + all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None) + torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None) + all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list) + all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list) + + # Error checks + if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)): + err_msg = 'Duplicate shard ids loaded by different ranks' + if torch.distributed.get_rank() == 0: + logger.error( + f'{err_msg}. Shards ids by rank:' + f' {[lt.keys() for lt in all_loaded_objects_list]}' + ) + raise CheckpointingException(err_msg) + + return all_loaded_objects + + +@torch.no_grad() +@debug_time("exchange_loaded_tensors_broadcast", logger) +def exchange_loaded_tensors_broadcast( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks by a series of broadcasts. + + For each rank for each loaded tensor do a broadcast to the whole group. + A reasonable tradeoff in terms of performance and simplicity. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + """ + main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution + local_rank = torch.distributed.get_rank(group=parallelization_group) + + all_loaded_tensors = dict(loaded_tensors) + + for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()): + if len(all_ranks_for_shard[shard_id]) == 1: + assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( + f'When there is only 1 ranks that needs a given shard,' + f' it should be the loading rank.' + f'Got: needs [{all_ranks_for_shard[shard_id][0]}]' + f' vs loads [{main_rank_for_shard[shard_id]}]' + ) + # Skipping the exchange since only the loading rank needs this tensor + # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case, + # e.g. P2P exchange. Currently handling this case saves most of the work though. + continue + if rank == local_rank: + assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) + orig_device = all_loaded_tensors[shard_id].device + local_ten = all_loaded_tensors[shard_id].cuda() + else: + local_ten, orig_device = _get_empty_tensor_for_exchange( + shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors + ) + + # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 + # It's ok to keep the nominal dtype after exchange, because TE will handle + # this during state dict load. 
+ # TODO: remove it once the bug is fixed + if is_float8tensor(local_ten): + local_ten = local_ten.from_float8() + all_loaded_tensors[shard_id] = local_ten + + global_src_rank = ( + rank + if parallelization_group == None + else torch.distributed.get_global_rank(parallelization_group, rank) + ) + # We can do async_op=True only if there is no CPU-copy follow-up + torch.distributed.broadcast( + local_ten, + src=global_src_rank, + group=parallelization_group, + async_op=orig_device is None, + ) + # Move tensor back to CPU if originally was on CPU + if orig_device is not None: + all_loaded_tensors[shard_id] = local_ten.to(orig_device) + del local_ten + + return all_loaded_tensors + + +def exchange_by_distribution( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + exchange_algo='broadcast', +) -> Dict[_ShardId, torch.Tensor]: + """Exchange tensors loaded by different ranks using the specified exchange_algo. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + exchange_algo (str): The algorithm used for performing exchanges. + Defaults to 'broadcast'. + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + """ + + assert shard_distribution is not None, 'Expecting distribution to perform exchange' + if exchange_algo == 'gather_object': + exchange_fn = exchange_loaded_tensors_gather_object + elif exchange_algo == 'gather_rounds': + exchange_fn = exchange_loaded_tensors_gather_rounds + elif exchange_algo == 'broadcast': + exchange_fn = exchange_loaded_tensors_broadcast + else: + raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}') + return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group) diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py index d376c6374baf7053b549f0686713e0e1c672d7b2..6bf35324d2d2867430e17e2fdae293236e32724c 100644 --- a/megatron/core/dist_checkpointing/mapping.py +++ b/megatron/core/dist_checkpointing/mapping.py @@ -1,723 +1,725 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Core library classes for representing sharding of tensors and objects. - -The main expected usage is wrapping torch.Tensors in state dicts with -ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod). 
-""" - -import logging -from abc import ABC, abstractmethod -from dataclasses import dataclass, field, replace -from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch - -from .core import CheckpointingException -from .dict_utils import dict_list_map_inplace - -logger = logging.getLogger(__name__) - -# These type definitions are just hints to differentiate a plain model state -# dict (StateDict) from a state dict with tensors replaced with ShardedTensors -# (ShardedStateDict). -StateDict = Dict[str, Any] -CommonStateDict = Dict[str, Any] -ShardedStateDict = Dict[str, Any] -ReplicaId = Union[int, Tuple[int, ...]] - - -class ShardedBase(ABC): - """Base class for ShardedTensor and ShardedStateDict.""" - - key: str - data: object - replica_id: ReplicaId - - @abstractmethod - def validate_metadata_integrity(self): - """Codifies the constraints on metadata attributes.""" - - @abstractmethod - def without_data(self) -> 'ShardedBase': - """Returns a new ShardedBase instance with data=None.""" - raise NotImplementedError - - -@dataclass -class ShardedTensor(ShardedBase): - """Represents a mapping between a local tensor and a global tensor. - - Global tensor is assumed to consist of many local tensors distributed - between different processes. - - Args: - key: unique identifier of a global tensor - data: local tensor data. Can be None only for consistency validation - dtype: tensor dtype - local_shape: local tensor shape - global_shape: global tensor shape - global_offset: offset of a local tensor in a global tensor, - specified in number of tensor elements - axis_fragmentations: global tensor fragmentation of each axis - replica_id: indicates given local tensor's replication wrt. - local tensors in different processes - prepend_axis_num: number of axes prepended to the local tensor to - reflect global tensor shape. The behavior is similar to - unsqueezing the local tensor. - allow_shape_mismatch: if True, during loading, the global shape of - a stored tensor does not have to match the expected global shape. - Useful for representing tensors with flexible shape, - e.g. padded. - flattened_range: specifies a slice that should be applied to a - flattened tensor with `local_shape` in order to get - the tensor stored as `data` - """ - - key: str - data: Optional[torch.Tensor] = field(repr=False) - dtype: torch.dtype - local_shape: Tuple[int, ...] - global_shape: Tuple[int, ...] - global_offset: Tuple[int, ...] - axis_fragmentations: Optional[Tuple[int, ...]] - replica_id: ReplicaId = 0 - prepend_axis_num: int = 0 - allow_shape_mismatch: bool = False - flattened_range: Optional[slice] = None - - def __post_init__(self): - self.validate_metadata_integrity() - - def validate_metadata_integrity(self) -> None: - """Codifies the constraints on metadata attributes. - - Meeting those constraints is guaranteed when instantiating a ShardedTensor - class with `from_rank_offsets` or `from_rank_offsets_flat` constructors. 
- - Returns: - None - """ - has_flattened_range = self.flattened_range is not None - if self.data is not None: - if self.data.dtype != self.dtype: - raise CheckpointingException( - f'Data dtype should match `dtype` attribute for {self}' - ) - if not has_flattened_range and self.data.shape != self.local_shape: - raise CheckpointingException( - f'Data shape should match `local_shape` attribute for {self}' - ) - if has_flattened_range: - if self.data.ndim != 1: - raise CheckpointingException(f'Data should be 1D for a flattened {self}') - real_data = self.data - try: - self.data = None - self.init_data(device='meta') - if self.data.shape != real_data.shape: - raise CheckpointingException( - f'Data shape {real_data.shape} doesnt match' - f' expected {self.data.shape} for {self}' - ) - finally: - self.data = real_data - - if len(self.global_shape) != len(self.global_offset): - raise CheckpointingException( - f'Global offset dimensions should be equal to global shape dimensions for {self}' - ) - if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape): - raise CheckpointingException( - f'Local shape together with `prepend_axis_num` dimensions should be ' - f'equal to global shape dimensions for {self}' - ) - - for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape): - if off % sh != 0: - raise CheckpointingException( - f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.' - ) - - if has_flattened_range and self.flattened_range.step is not None: - raise CheckpointingException( - f'`step` argument in the flattened range of a ShardedTensor is not supported.' - ) - - def global_slice(self) -> Tuple[Union[int, slice], ...]: - """ - Returns a tuple of int and slice objects representing a slice of the - global tensor that this ShardedTensor corresponds to. - """ - assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num - return tuple( - chain( - (off for off in self.global_offset[: self.prepend_axis_num]), - ( - slice(off, off + sh) - for off, sh in zip( - self.global_offset[self.prepend_axis_num :], self.local_shape - ) - ), - ) - ) - - def global_coordinates(self) -> Tuple[np.ndarray, ...]: - """ - Returns a tuple of np.ndarrays representing the coordinates of the global tensor - that this ShardedTensor corresponds to. - """ - if self.flattened_range is None: - raise CheckpointingException( - f'`global_coordinates` is undefined for' - f' {self.__class__.__name__} without `flattened_range`' - ) - - local_coords = self.local_coordinates() - assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), ( - len(local_coords), - self, - ) - global_coords = tuple( - c + off - for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset) - ) - return global_coords - - def local_coordinates(self) -> Tuple[np.ndarray, ...]: - """ - Returns a tuple of np.ndarrays representing the coordinates of the local tensor - that this ShardedTensor corresponds to. - """ - if self.flattened_range is None: - raise CheckpointingException( - f'`local_coordinates` is undefined for' - f' {self.__class__.__name__} without `flattened_range`' - ) - - # TODO: np.unravel_index? - mask = np.zeros(np.product(self.local_shape), dtype=bool) - mask[self.flattened_range] = True - return np.nonzero(mask.reshape(self.local_shape)) - - def local_chunk_offset_in_global(self) -> Tuple[int, ...]: - """Offset of a local chunk in a global array of chunks. 
- - Returns: - Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks. - """ - assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num - chunk_offset = list(self.global_offset[: self.prepend_axis_num]) - for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape): - assert off % sh == 0, str(self) - chunk_offset.append(off // sh) - return tuple(chunk_offset) - - def max_allowed_chunks(self) -> Tuple[int, ...]: - """ - Returns the maximum allowed chunks for this ShardedTensor. - """ - chunks = [] - for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations): - if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0: - raise CheckpointingException( - f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm}' - ) - axis_chunk_size = axis_sh // axis_fragm - chunks.append(axis_chunk_size) - return tuple(chunks) - - def without_data(self): - return replace(self, data=None) - - @classmethod - def from_rank_offsets( - cls, - key: str, - data: torch.Tensor, - *rank_offsets: Tuple[int, int, int], - replica_id: ReplicaId = 0, - prepend_axis_num: int = 0, - flattened_range: None = None, - **init_kwargs, - ): - """Allows to construct the ShardedTensor given offset specified in process ranks. - - Args: - key (str): unique key - data (torch.Tensor): local tensor data - rank_offsets (Tuple[int, int, int]): each tuple - (axis, axis_rank_offset, axis_fragm) says that if - global tensor is divided into `axis_fragm` fragment along `axis` - axis, then local tensor data corresponds to the `axis_rank_offset` chunk. - replica_id (ReplicaId): see ShardedTensor - prepend_axis_num (int): see ShardedTensor - flattened_range (None): must be None when using this constructor - init_kwargs: passed to ShardedTensor.__init__ - """ - if flattened_range is not None: - raise ValueError( - 'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.' - ' Use `from_rank_offsets_flat` instead' - ) - global_offset = [0] * (data.ndim + prepend_axis_num) - global_shape = ([1] * prepend_axis_num) + list(data.shape) - axis_fragmentations = [1] * (data.ndim + prepend_axis_num) - _seen_axis = set() - for axis, axis_rank_offset, axis_fragm in rank_offsets: - if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm: - raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.') - _seen_axis.add(axis) - - local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num] - global_shape[axis] = axis_fragm * local_axis_shape - global_offset[axis] = axis_rank_offset * local_axis_shape - axis_fragmentations[axis] = axis_fragm - - return cls( - key, - data, - data.dtype, - tuple(data.shape), - tuple(global_shape), - tuple(global_offset), - tuple(axis_fragmentations), - replica_id, - prepend_axis_num, - flattened_range=flattened_range, - **init_kwargs, - ) - - @classmethod - def from_rank_offsets_flat( - cls, - key: str, - data: torch.Tensor, - non_flat_local_shape: Tuple[int, ...], - *args, - flattened_range: Optional[slice] = None, - **kwargs, - ): - """Allows to construct a *flattened* ShardedTensor given offset specified in process ranks. - - Args: - key (str): - data (torch.Tensor): this should be a flattened data tensor - non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk - *args: passed unchanged to the `from_rank_offsets` constructor - flattened_range (slice): see ShardedTensor. 
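For illustration, a hedged sketch of typical from_rank_offsets usage for a weight whose first axis is split across tensor-parallel ranks; the key and shapes are made up:

import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor

tp_rank, tp_size = 1, 4
local_weight = torch.empty(128, 512)  # local shard of a (512, 512) global weight
sh_ten = ShardedTensor.from_rank_offsets(
    'decoder.layers.0.linear.weight',  # hypothetical unique key of the global tensor
    local_weight,
    (0, tp_rank, tp_size),  # axis 0 is split into tp_size chunks; this rank holds chunk tp_rank
)
assert sh_ten.global_shape == (512, 512)
assert sh_ten.global_offset == (128, 0)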
Defaults to None, but must be set to - a non-None slice. - **kwargs: - - Returns: - ShardedTensor: constructed ShardedTensor instance - """ - if flattened_range is None: - raise CheckpointingException( - 'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.' - ' Use `from_rank_offsets` instead' - ) - if data.ndim != 1: - raise CheckpointingException( - f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}' - ) - if flattened_range.stop - flattened_range.start != data.numel(): - raise CheckpointingException( - f'Flattened ShardedTensor data length ({data.numel()}) must meet the ' - f'slice length: {flattened_range.stop - flattened_range.start}' - ) - - non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta') - sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs) - instance = replace(sh_ten, data=data, flattened_range=flattened_range) - instance.validate_metadata_integrity() - return instance - - def init_data(self, device: Union[str, torch.device], init_fn=torch.empty): - """ - Initialize the tensor data of this ShardedTensor. - - Only called if `data` attribute is None. - - Args: - device (Union[str, torch.device]): device to place the tensor on - init_fn (Callable, optional): function to use to initialize the tensor. - Defaults to `torch.empty`. - """ - if self.data is not None: - return - self.data = init_fn(self.local_shape, dtype=self.dtype, device=device) - if self.flattened_range is not None: - self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop] - - def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']: - """This is an analogue of torch.narrow for ShardedTensors. - - Narrowing assumes that we narrow a local tensor on each rank. - This has consequences on local_shape, global_shape, global_offset, etc. - - Args: - dim (int): dimension to narrow. Doesn't include prepended axes. - start (int): start element - length (int): length of the slice - - Returns: - List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors, - the list will always have 1 element. For flat ShardedTensors the number of - elements varies depending on `dim` and on overlap, because flat - tensors must be contiguous. In particular the list can be empty. 
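For illustration, a minimal sketch of from_rank_offsets_flat: the non-flat local chunk has shape (3, 4), but this rank only holds elements 5..10 of its flattened form (the key and values are made up):

import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor

flat_data = torch.arange(5, dtype=torch.float32)  # 5 elements, matching slice(5, 10)
sh_ten = ShardedTensor.from_rank_offsets_flat(
    'layer.weight',  # hypothetical key
    flat_data,
    (3, 4),  # expected non-flat local shape
    flattened_range=slice(5, 10),
)
assert sh_ten.local_shape == (3, 4)
assert sh_ten.data.shape == (5,)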
- """ - prepended_dim = dim + self.prepend_axis_num - local_length_along_dim = self.local_shape[dim] - - def _update_tuple(x, ind, val): - x = list(x) - x[ind] = val - return tuple(x) - - def _safe_div(x, y): - assert x % y == 0, (x, y) - return x // y - - # Decrease global shape and global offset by `length / local_length_along_dim` - assert ( - self.global_shape[prepended_dim] % local_length_along_dim == 0 - ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' - assert ( - self.global_offset[prepended_dim] % local_length_along_dim == 0 - ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' - global_shape = _update_tuple( - self.global_shape, - prepended_dim, - _safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim), - ) - global_offset = _update_tuple( - self.global_offset, - prepended_dim, - _safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim), - ) - - if self.flattened_range is None: - new_data = self.data.narrow(dim, start, length) - # always a single result tensor - return [ - replace( - self, - data=new_data, - local_shape=new_data.shape, - global_shape=global_shape, - global_offset=global_offset, - ) - ] - else: - if dim != 0: - raise CheckpointingException( - f'Narrowing along the first axis is supported for now only, got dim={dim}' - ) - - # If dim=0, we will always get 0 or 1 resulting tensor. - # If dim>1, in general there can be more result tensors (e.g. max 3 for dim=1) - - # For on original flat ShardedTensor of local shape [3, 4] and - # flattened_range=slice(5, 10), - # the X signs mark the actual (flat) data in `self.data` - # notice 12 (3*4) total "virtual" elements, out of which 5 is actual data. - # flat original: [.....XXXXX..] - - # If we narrow to start=1, length=1 in the original local shape dimensions, - # the overlapping flat slice would be: - # narrow to: [....XXXX....] - # flat overlap: [.....XXX....] - - # Now `data` is flattened and sliced, so we must compute local_shape manually - local_shape = _update_tuple(self.local_shape, dim, length) - other_dims_volume = np.prod( - _update_tuple(local_shape, dim, 1) - ) # 4 in the example above - volume_before_split = other_dims_volume * start # 4 in the example above - volume_of_split = other_dims_volume * length # 4 in the example above - - flat_slice_start_shifted = ( - self.flattened_range.start - volume_before_split - ) # 5 - 4 = 1 in the example above - flat_slice_stop_shifted = ( - self.flattened_range.stop - volume_before_split - ) # 10 - 4 = 6 in the example above - - # Find an intersection of - # (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split) - - if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split: - return [] # no intersection - - # new_flattened_range = slice(1, 4) in the example above - new_flattened_range = slice( - max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split) - ) - # Apply the intersection to the flattened data tensor. 
- # Compute start and slice appropriate length - intersection_slice_start = ( - new_flattened_range.start - flat_slice_start_shifted - ) # 0 in the example above - new_data = self.data[ - intersection_slice_start : intersection_slice_start - + new_flattened_range.stop - - new_flattened_range.start - ] - - return [ - replace( - self, - data=new_data, - local_shape=local_shape, - global_shape=global_shape, - global_offset=global_offset, - flattened_range=new_flattened_range, - ) - ] - - -def is_main_replica(replica_id: ReplicaId): - """Checks if given `replica_id` is considered as main. - - "Main" replica is: - - integer 0 - - or an iterable with all 0 elements - - It is the application responsibility to set correct replicas for sharded tensors. - - Args: - replica_id (Union[int, Tuple[int, ...]]): replica id - - Returns: - (bool): True for a "main" replica - """ - if isinstance(replica_id, int): - return replica_id == 0 - return all(r == 0 for r in replica_id) - - -class LocalNonpersistentObject: - """Object that should not be stored in a checkpoint, but restored locally. - - Wrapping any object inside the state dict with LocalNonpersistentObject - will result in: - - during saving, this object will *not* be stored in the checkpoint - - during loading, a local version of this object will be placed in a state dict - """ - - def __init__(self, obj): - self.obj = obj - - def unwrap(self): - """Returns the original object.""" - return self.obj - - -# TODO: Delete once NeMo fixes typo. -LocalNonpersitentObject = LocalNonpersistentObject - - -@dataclass -class ShardedObject(ShardedBase): - """Represents a mapping between a local object and a global object. - - Global object is assumed to consist of many local objects distributed - between different processes. - - NOTE: Contrary to ShardedTensor, it's impossible to change global object - sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor - with atomic arbitrary typed elements. - - Args: - key: unique identifier of a global tensor - data: local object data. Can be None only for consistency validation - global_shape: global object shape - global_offset: offset of a local object in a global object, specified in number of shards - replica_id: indicates local object replication wrt. local objects in different processes - """ - - key: str - data: object - global_shape: Tuple[int, ...] - global_offset: Tuple[int, ...] - replica_id: ReplicaId = 0 - - def __post_init__(self): - self.validate_metadata_integrity() - - def validate_metadata_integrity(self): - if len(self.global_shape) != len(self.global_offset): - raise CheckpointingException( - f'Global offset dimensions should be equal to global shape dimensions for {self}' - ) - - def without_data(self): - return replace(self, data=None) - - @property - def unique_key(self): - """returns a unique key for this object""" - return ( - f'{self.key}/shard_' - f'{".".join(map(str, self.global_offset))}_' - f'{".".join(map(str, self.global_shape))}' - ) - - def __str__(self): - return f'{self.__class__.__name__}(key=\'{self.key}\')' - - @classmethod - def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject': - """Instantiates a ShardedObject from a unique key. - - Args: - unique_key: a string of the form - /shard__ - replica_id: indicates local object replication wrt. 
- local objects in different processes - - Returns: - a ShardedObject with data=None - """ - key, shard_key = unique_key.split('/') - shard_str, offset, shape = shard_key.split('_') - assert shard_str == 'shard' - offset = tuple(map(int, offset.split('.'))) - shape = tuple(map(int, shape.split('.'))) - if len(shape) + 1 == len(offset): - # This is a backward-compatible fix. We don't know the last - # element of global shape so set it to -1. - shape += (-1,) - return cls(key, None, shape, offset, replica_id) - - -FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict] -FactoryMergeFn = Callable[[StateDict], torch.Tensor] - - -@dataclass -class ShardedTensorFactory(ShardedBase): - """Allows to apply transformations to tensors before/after serialization. - - The essence of those transformations is that they can be applied to - optimizer states the same way they are applied to the model params. - The ultimate state dict with sharded tensors must depend functionally on - `build_fn` arguments (key, data, replica_id, flattened_range), - which will be provided by the optimizer. - - Builder creates a sub-state-dict out of a tensor before saving, and merger - merges the corresponding state dict after loading. - - Args: - key (str): unique identifier of the factory - data (torch.Tensor): original model parameter that will be further - transformed by this factory - build_fn (callable): function that transforms the original tensor - to a sharded state dict - merge_fn (callable): function that transforms loaded subtree back - into a single tensor (inverse of `build_fn`) - replica_id (ReplicaId): indicates factory replication wrt. - factories in different processes - flattened_range (slice, optional): indicates additional flattening - applied to the ShardedTensors produced by the factory - """ - - key: str - data: torch.Tensor - build_fn: FactoryBuildFn - merge_fn: FactoryMergeFn - replica_id: ReplicaId = 0 - flattened_range: Optional[slice] = None - - def build(self): - """Builds a ShardedStateDict from the original tensor""" - return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range) - - def validate_metadata_integrity(self): - """No reasonable checks can be applied""" - pass - - def without_data(self): - return replace(self, data=None) - - -def apply_factories(sharded_state_dict: ShardedStateDict): - """Turn ShardedTensorFactories into ShardedTensors *in-place*. - - Args: - sharded_state_dict (ShardedStateDict): state dict possibly - containing ShardedTensorFactory objects - - Returns: - None: state dict is modified in place - """ - - def apply(x): - if isinstance(x, ShardedTensorFactory): - x = x.build() - return x - - dict_list_map_inplace(apply, sharded_state_dict) - - -def apply_factory_merges( - x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = () -) -> StateDict: - """Apply merges defined by ShardedTensorFactories *in-place*. - - Args: - x1 (StateDict): state dict loaded from the checkpoint - x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) - with ShardedTensorFactory - as (possibly nested) values that define how to - merge objects from the `x1` state dict - key (Tuple[str, ...]): current key in a recursive call. 
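For illustration, a minimal, hypothetical ShardedTensorFactory whose build_fn splits a parameter into two ShardedTensors along axis 0 and whose merge_fn concatenates them back after loading; the key and shapes are made up and flattened factories are not handled:

import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor, ShardedTensorFactory

def build_fn(key, data, replica_id, flattened_range):
    assert flattened_range is None  # this sketch does not handle flattened factories
    half = data.shape[0] // 2
    return {
        'top': ShardedTensor.from_rank_offsets(f'{key}.top', data[:half], replica_id=replica_id),
        'bottom': ShardedTensor.from_rank_offsets(f'{key}.bottom', data[half:], replica_id=replica_id),
    }

def merge_fn(loaded_subtree):
    # Inverse of build_fn: reassemble the original tensor from the loaded parts.
    return torch.cat([loaded_subtree['top'], loaded_subtree['bottom']], dim=0)

factory = ShardedTensorFactory('layer.weight', torch.zeros(8, 4), build_fn, merge_fn)
sub_state_dict = factory.build()  # what apply_factories substitutes in place of the factory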
- Used only for reporting meaningful errors - - Returns: - StateDict: `x1` modified in-place - """ - if isinstance(x2, ShardedTensorFactory): - return x2.merge_fn(x1) - - # There rest is almost the same as the `merge` function from `dict_utils` - if isinstance(x1, dict) and isinstance(x2, dict): - for k, v2 in x2.items(): - if k not in x1: - raise ValueError( - f'Different dict keys encountered in `apply_factory_merges` ' - f'({x1.keys()} vs {x2.keys()})' - ) - else: - x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) - elif isinstance(x1, list) and isinstance(x2, list): - if len(x1) != len(x2): - err_msg = ( - f'Cannot merge two lists with different lengths ' - f'({len(x1)} and {len(x2)}, encountered at key {key})' - ) - logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}') - raise ValueError(err_msg) - for i, v2 in enumerate(x2): - x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,)) - elif isinstance(x1, list) and isinstance(x2, dict): - for k, v2 in x2.items(): - if not isinstance(k, int): - raise ValueError( - f'Invalid dict key {k} non-integer type encountered ' - f'in a list-dict merge at level {key}' - ) - if k >= len(x1): - raise ValueError( - f'Dict key {k} out of bound for list of length' - f'{len(x1)} (encountered at level {key})' - ) - x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) - else: - raise ValueError( - f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`' - ) - return x1 +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Core library classes for representing sharding of tensors and objects. + +The main expected usage is wrapping torch.Tensors in state dicts with +ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod). +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from .core import CheckpointingException +from .dict_utils import dict_list_map_inplace + +logger = logging.getLogger(__name__) + +# These type definitions are just hints to differentiate a plain model state +# dict (StateDict) from a state dict with tensors replaced with ShardedTensors +# (ShardedStateDict). +StateDict = Dict[str, Any] +CommonStateDict = Dict[str, Any] +ShardedStateDict = Dict[str, Any] +ReplicaId = Union[int, Tuple[int, ...]] + + +class ShardedBase(ABC): + """Base class for ShardedTensor and ShardedStateDict.""" + + key: str + data: object + replica_id: ReplicaId + + @abstractmethod + def validate_metadata_integrity(self): + """Codifies the constraints on metadata attributes.""" + + @abstractmethod + def without_data(self) -> 'ShardedBase': + """Returns a new ShardedBase instance with data=None.""" + raise NotImplementedError + + +@dataclass +class ShardedTensor(ShardedBase): + """Represents a mapping between a local tensor and a global tensor. + + Global tensor is assumed to consist of many local tensors distributed + between different processes. + + Args: + key: unique identifier of a global tensor + data: local tensor data. Can be None only for consistency validation + dtype: tensor dtype + local_shape: local tensor shape + global_shape: global tensor shape + global_offset: offset of a local tensor in a global tensor, + specified in number of tensor elements + axis_fragmentations: global tensor fragmentation of each axis + replica_id: indicates given local tensor's replication wrt. 
+ local tensors in different processes + prepend_axis_num: number of axes prepended to the local tensor to + reflect global tensor shape. The behavior is similar to + unsqueezing the local tensor. + allow_shape_mismatch: if True, during loading, the global shape of + a stored tensor does not have to match the expected global shape. + Useful for representing tensors with flexible shape, + e.g. padded. + flattened_range: specifies a slice that should be applied to a + flattened tensor with `local_shape` in order to get + the tensor stored as `data` + """ + + key: str + data: Optional[torch.Tensor] = field(repr=False) + dtype: torch.dtype + local_shape: Tuple[int, ...] + global_shape: Tuple[int, ...] + global_offset: Tuple[int, ...] + axis_fragmentations: Optional[Tuple[int, ...]] + replica_id: ReplicaId = 0 + prepend_axis_num: int = 0 + allow_shape_mismatch: bool = False + flattened_range: Optional[slice] = None + + def __post_init__(self): + self.validate_metadata_integrity() + + def validate_metadata_integrity(self) -> None: + """Codifies the constraints on metadata attributes. + + Meeting those constraints is guaranteed when instantiating a ShardedTensor + class with `from_rank_offsets` or `from_rank_offsets_flat` constructors. + + Returns: + None + """ + has_flattened_range = self.flattened_range is not None + if self.data is not None: + if self.data.dtype != self.dtype: + raise CheckpointingException( + f'Data dtype should match `dtype` attribute for {self}' + ) + if not has_flattened_range and self.data.shape != self.local_shape: + raise CheckpointingException( + f'Data shape should match `local_shape` attribute for {self}' + ) + if has_flattened_range: + if self.data.ndim != 1: + raise CheckpointingException(f'Data should be 1D for a flattened {self}') + real_data = self.data + try: + self.data = None + self.init_data(device='meta') + if self.data.shape != real_data.shape: + raise CheckpointingException( + f'Data shape {real_data.shape} doesnt match' + f' expected {self.data.shape} for {self}' + ) + finally: + self.data = real_data + + if len(self.global_shape) != len(self.global_offset): + raise CheckpointingException( + f'Global offset dimensions should be equal to global shape dimensions for {self}' + ) + if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape): + raise CheckpointingException( + f'Local shape together with `prepend_axis_num` dimensions should be ' + f'equal to global shape dimensions for {self}' + ) + + for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape): + # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally. + # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1. + # GPU0 receives p0 and a portion of p1, while GPU1 receives the + # remaining portion of p1 and p2. + # As a result, there is no parameter shard of p2 on GPU0, and + # the shape of p2 on GPU0 is zero. + if sh != 0 and off % sh != 0: + raise CheckpointingException( + f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.' + ) + + if has_flattened_range and self.flattened_range.step is not None: + raise CheckpointingException( + f'`step` argument in the flattened range of a ShardedTensor is not supported.' + ) + + def global_slice(self) -> Tuple[Union[int, slice], ...]: + """ + Returns a tuple of int and slice objects representing a slice of the + global tensor that this ShardedTensor corresponds to. 
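+
+        For example (illustrative values only): with `prepend_axis_num=1`,
+        `global_offset=(2, 0, 4)` and `local_shape=(8, 4)`, the result is
+        `(2, slice(0, 8), slice(4, 8))`.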
+        """
+        assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
+        return tuple(
+            chain(
+                (off for off in self.global_offset[: self.prepend_axis_num]),
+                (
+                    slice(off, off + sh)
+                    for off, sh in zip(
+                        self.global_offset[self.prepend_axis_num :], self.local_shape
+                    )
+                ),
+            )
+        )
+
+    def global_coordinates(self) -> Tuple[np.ndarray, ...]:
+        """
+        Returns a tuple of np.ndarrays representing the coordinates of the global tensor
+        that this ShardedTensor corresponds to.
+        """
+        if self.flattened_range is None:
+            raise CheckpointingException(
+                f'`global_coordinates` is undefined for'
+                f' {self.__class__.__name__} without `flattened_range`'
+            )
+
+        local_coords = self.local_coordinates()
+        assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
+            len(local_coords),
+            self,
+        )
+        global_coords = tuple(
+            c + off
+            for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
+        )
+        return global_coords
+
+    def local_coordinates(self) -> Tuple[np.ndarray, ...]:
+        """
+        Returns a tuple of np.ndarrays representing the coordinates of the local tensor
+        that this ShardedTensor corresponds to.
+        """
+        if self.flattened_range is None:
+            raise CheckpointingException(
+                f'`local_coordinates` is undefined for'
+                f' {self.__class__.__name__} without `flattened_range`'
+            )
+
+        # TODO: np.unravel_index?
+        mask = np.zeros(np.prod(self.local_shape), dtype=bool)
+        mask[self.flattened_range] = True
+        return np.nonzero(mask.reshape(self.local_shape))
+
+    def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
+        """Offset of a local chunk in a global array of chunks.
+
+        Returns:
+            Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
+        """
+        assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
+        chunk_offset = list(self.global_offset[: self.prepend_axis_num])
+        for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
+            assert off % sh == 0, str(self)
+            chunk_offset.append(off // sh)
+        return tuple(chunk_offset)
+
+    def max_allowed_chunks(self) -> Tuple[int, ...]:
+        """
+        Returns the maximum allowed chunks for this ShardedTensor.
+        """
+        chunks = []
+        for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
+            if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
+                raise CheckpointingException(
+                    f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
+                )
+            axis_chunk_size = axis_sh // axis_fragm
+            chunks.append(axis_chunk_size)
+        return tuple(chunks)
+
+    def without_data(self):
+        return replace(self, data=None)
+
+    @classmethod
+    def from_rank_offsets(
+        cls,
+        key: str,
+        data: torch.Tensor,
+        *rank_offsets: Tuple[int, int, int],
+        replica_id: ReplicaId = 0,
+        prepend_axis_num: int = 0,
+        flattened_range: None = None,
+        **init_kwargs,
+    ):
+        """Allows to construct the ShardedTensor given offset specified in process ranks.
+
+        Args:
+            key (str): unique key
+            data (torch.Tensor): local tensor data
+            rank_offsets (Tuple[int, int, int]): each tuple
+                (axis, axis_rank_offset, axis_fragm) says that if
+                global tensor is divided into `axis_fragm` fragments along `axis`
+                axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
+            replica_id (ReplicaId): see ShardedTensor
+            prepend_axis_num (int): see ShardedTensor
+            flattened_range (None): must be None when using this constructor
+            init_kwargs: passed to ShardedTensor.__init__
+        """
+        if flattened_range is not None:
+            raise ValueError(
+                'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
+                ' Use `from_rank_offsets_flat` instead'
+            )
+        global_offset = [0] * (data.ndim + prepend_axis_num)
+        global_shape = ([1] * prepend_axis_num) + list(data.shape)
+        axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
+        _seen_axis = set()
+        for axis, axis_rank_offset, axis_fragm in rank_offsets:
+            if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm:
+                raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.')
+            # Each axis may be specified at most once
+            if axis in _seen_axis:
+                raise CheckpointingException(f'Duplicated axis {axis} in rank offsets for key {key}.')
+            _seen_axis.add(axis)
+
+            local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
+            global_shape[axis] = axis_fragm * local_axis_shape
+            global_offset[axis] = axis_rank_offset * local_axis_shape
+            axis_fragmentations[axis] = axis_fragm
+
+        return cls(
+            key,
+            data,
+            data.dtype,
+            tuple(data.shape),
+            tuple(global_shape),
+            tuple(global_offset),
+            tuple(axis_fragmentations),
+            replica_id,
+            prepend_axis_num,
+            flattened_range=flattened_range,
+            **init_kwargs,
+        )
+
+    @classmethod
+    def from_rank_offsets_flat(
+        cls,
+        key: str,
+        data: torch.Tensor,
+        non_flat_local_shape: Tuple[int, ...],
+        *args,
+        flattened_range: Optional[slice] = None,
+        **kwargs,
+    ):
+        """Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
+
+        Args:
+            key (str): unique key
+            data (torch.Tensor): this should be a flattened data tensor
+            non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
+            *args: passed unchanged to the `from_rank_offsets` constructor
+            flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
+                a non-None slice.
+            **kwargs: passed unchanged to the `from_rank_offsets` constructor
+
+        Returns:
+            ShardedTensor: constructed ShardedTensor instance
+        """
+        if flattened_range is None:
+            raise CheckpointingException(
+                'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
+                ' Use `from_rank_offsets` instead'
+            )
+        if data.ndim != 1:
+            raise CheckpointingException(
+                f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
+            )
+        if flattened_range.stop - flattened_range.start != data.numel():
+            raise CheckpointingException(
+                f'Flattened ShardedTensor data length ({data.numel()}) must match the '
+                f'slice length: {flattened_range.stop - flattened_range.start}'
+            )
+
+        non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
+        sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
+        instance = replace(sh_ten, data=data, flattened_range=flattened_range)
+        instance.validate_metadata_integrity()
+        return instance
+
+    def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
+        """
+        Initialize the tensor data of this ShardedTensor.
+
+        Only called if `data` attribute is None.
+
+        Args:
+            device (Union[str, torch.device]): device to place the tensor on
+            init_fn (Callable, optional): function to use to initialize the tensor.
+                Defaults to `torch.empty`.
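+
+        Note (illustrative example): when `flattened_range` is set, the full `local_shape`
+        tensor is allocated and then narrowed to the flat slice, e.g. for
+        `local_shape=(3, 4)` and `flattened_range=slice(5, 10)` the resulting `data`
+        is a 1D tensor with 5 elements.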
+ """ + if self.data is not None: + return + self.data = init_fn(self.local_shape, dtype=self.dtype, device=device) + if self.flattened_range is not None: + self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop] + + def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']: + """This is an analogue of torch.narrow for ShardedTensors. + + Narrowing assumes that we narrow a local tensor on each rank. + This has consequences on local_shape, global_shape, global_offset, etc. + + Args: + dim (int): dimension to narrow. Doesn't include prepended axes. + start (int): start element + length (int): length of the slice + + Returns: + List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors, + the list will always have 1 element. For flat ShardedTensors the number of + elements varies depending on `dim` and on overlap, because flat + tensors must be contiguous. In particular the list can be empty. + """ + prepended_dim = dim + self.prepend_axis_num + local_length_along_dim = self.local_shape[dim] + + def _update_tuple(x, ind, val): + x = list(x) + x[ind] = val + return tuple(x) + + def _safe_div(x, y): + assert x % y == 0, (x, y) + return x // y + + # Decrease global shape and global offset by `length / local_length_along_dim` + assert ( + self.global_shape[prepended_dim] % local_length_along_dim == 0 + ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' + assert ( + self.global_offset[prepended_dim] % local_length_along_dim == 0 + ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' + global_shape = _update_tuple( + self.global_shape, + prepended_dim, + _safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim), + ) + global_offset = _update_tuple( + self.global_offset, + prepended_dim, + _safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim), + ) + + if self.flattened_range is None: + new_data = self.data.narrow(dim, start, length) + # always a single result tensor + return [ + replace( + self, + data=new_data, + local_shape=new_data.shape, + global_shape=global_shape, + global_offset=global_offset, + ) + ] + else: + if dim != 0: + raise CheckpointingException( + f'Narrowing along the first axis is supported for now only, got dim={dim}' + ) + + # If dim=0, we will always get 0 or 1 resulting tensor. + # If dim>1, in general there can be more result tensors (e.g. max 3 for dim=1) + + # For on original flat ShardedTensor of local shape [3, 4] and + # flattened_range=slice(5, 10), + # the X signs mark the actual (flat) data in `self.data` + # notice 12 (3*4) total "virtual" elements, out of which 5 is actual data. + # flat original: [.....XXXXX..] + + # If we narrow to start=1, length=1 in the original local shape dimensions, + # the overlapping flat slice would be: + # narrow to: [....XXXX....] + # flat overlap: [.....XXX....] 
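+            #
+            # In other words, the steps below (1) shift the stored flat slice by the volume
+            # preceding the narrowed region, (2) intersect it with the narrowed region's volume,
+            # and (3) map that intersection back onto the already-flattened `self.data`.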
+ + # Now `data` is flattened and sliced, so we must compute local_shape manually + local_shape = _update_tuple(self.local_shape, dim, length) + other_dims_volume = np.prod( + _update_tuple(local_shape, dim, 1) + ) # 4 in the example above + volume_before_split = other_dims_volume * start # 4 in the example above + volume_of_split = other_dims_volume * length # 4 in the example above + + flat_slice_start_shifted = ( + self.flattened_range.start - volume_before_split + ) # 5 - 4 = 1 in the example above + flat_slice_stop_shifted = ( + self.flattened_range.stop - volume_before_split + ) # 10 - 4 = 6 in the example above + + # Find an intersection of + # (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split) + + if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split: + return [] # no intersection + + # new_flattened_range = slice(1, 4) in the example above + new_flattened_range = slice( + max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split) + ) + # Apply the intersection to the flattened data tensor. + # Compute start and slice appropriate length + intersection_slice_start = ( + new_flattened_range.start - flat_slice_start_shifted + ) # 0 in the example above + new_data = self.data[ + intersection_slice_start : intersection_slice_start + + new_flattened_range.stop + - new_flattened_range.start + ] + + return [ + replace( + self, + data=new_data, + local_shape=local_shape, + global_shape=global_shape, + global_offset=global_offset, + flattened_range=new_flattened_range, + ) + ] + + +def is_main_replica(replica_id: ReplicaId): + """Checks if given `replica_id` is considered as main. + + "Main" replica is: + - integer 0 + - or an iterable with all 0 elements + + It is the application responsibility to set correct replicas for sharded tensors. + + Args: + replica_id (Union[int, Tuple[int, ...]]): replica id + + Returns: + (bool): True for a "main" replica + """ + if isinstance(replica_id, int): + return replica_id == 0 + return all(r == 0 for r in replica_id) + + +class LocalNonpersistentObject: + """Object that should not be stored in a checkpoint, but restored locally. + + Wrapping any object inside the state dict with LocalNonpersistentObject + will result in: + - during saving, this object will *not* be stored in the checkpoint + - during loading, a local version of this object will be placed in a state dict + """ + + def __init__(self, obj): + self.obj = obj + + def unwrap(self): + """Returns the original object.""" + return self.obj + + +@dataclass +class ShardedObject(ShardedBase): + """Represents a mapping between a local object and a global object. + + Global object is assumed to consist of many local objects distributed + between different processes. + + NOTE: Contrary to ShardedTensor, it's impossible to change global object + sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor + with atomic arbitrary typed elements. + + Args: + key: unique identifier of a global tensor + data: local object data. Can be None only for consistency validation + global_shape: global object shape + global_offset: offset of a local object in a global object, specified in number of shards + replica_id: indicates local object replication wrt. local objects in different processes + """ + + key: str + data: object + global_shape: Tuple[int, ...] + global_offset: Tuple[int, ...] 
+    replica_id: ReplicaId = 0
+
+    def __post_init__(self):
+        self.validate_metadata_integrity()
+
+    def validate_metadata_integrity(self):
+        if len(self.global_shape) != len(self.global_offset):
+            raise CheckpointingException(
+                f'Global offset dimensions should be equal to global shape dimensions for {self}'
+            )
+
+    def without_data(self):
+        return replace(self, data=None)
+
+    @property
+    def unique_key(self):
+        """returns a unique key for this object"""
+        return (
+            f'{self.key}/shard_'
+            f'{".".join(map(str, self.global_offset))}_'
+            f'{".".join(map(str, self.global_shape))}'
+        )
+
+    def __str__(self):
+        return f'{self.__class__.__name__}(key=\'{self.key}\')'
+
+    @classmethod
+    def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject':
+        """Instantiates a ShardedObject from a unique key.
+
+        Args:
+            unique_key: a string of the form `<key>/shard_<offset>_<shape>`,
+                where `<offset>` and `<shape>` are dot-separated integers
+                (see the `unique_key` property)
+            replica_id: indicates local object replication wrt.
+                local objects in different processes
+
+        Returns:
+            a ShardedObject with data=None
+        """
+        key, shard_key = unique_key.split('/')
+        shard_str, offset, shape = shard_key.split('_')
+        assert shard_str == 'shard'
+        offset = tuple(map(int, offset.split('.')))
+        shape = tuple(map(int, shape.split('.')))
+        if len(shape) + 1 == len(offset):
+            # This is a backward-compatible fix. We don't know the last
+            # element of global shape so set it to -1.
+            shape += (-1,)
+        return cls(key, None, shape, offset, replica_id)
+
+
+FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
+FactoryMergeFn = Callable[[StateDict], torch.Tensor]
+
+
+@dataclass
+class ShardedTensorFactory(ShardedBase):
+    """Allows to apply transformations to tensors before/after serialization.
+
+    The essence of those transformations is that they can be applied to
+    optimizer states the same way they are applied to the model params.
+    The ultimate state dict with sharded tensors must depend functionally on
+    `build_fn` arguments (key, data, replica_id, flattened_range),
+    which will be provided by the optimizer.
+
+    Builder creates a sub-state-dict out of a tensor before saving, and merger
+    merges the corresponding state dict after loading.
+
+    Args:
+        key (str): unique identifier of the factory
+        data (torch.Tensor): original model parameter that will be further
+            transformed by this factory
+        build_fn (callable): function that transforms the original tensor
+            to a sharded state dict
+        merge_fn (callable): function that transforms loaded subtree back
+            into a single tensor (inverse of `build_fn`)
+        replica_id (ReplicaId): indicates factory replication wrt.
+            factories in different processes
+        flattened_range (slice, optional): indicates additional flattening
+            applied to the ShardedTensors produced by the factory
+    """
+
+    key: str
+    data: torch.Tensor
+    build_fn: FactoryBuildFn
+    merge_fn: FactoryMergeFn
+    replica_id: ReplicaId = 0
+    flattened_range: Optional[slice] = None
+
+    def build(self):
+        """Builds a ShardedStateDict from the original tensor"""
+        return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
+
+    def validate_metadata_integrity(self):
+        """No reasonable checks can be applied"""
+        pass
+
+    def without_data(self):
+        return replace(self, data=None)
+
+
+def apply_factories(sharded_state_dict: ShardedStateDict):
+    """Turn ShardedTensorFactories into ShardedTensors *in-place*.
+
+    Args:
+        sharded_state_dict (ShardedStateDict): state dict possibly
+            containing ShardedTensorFactory objects
+
+    Returns:
+        None: state dict is modified in place
+    """
+
+    def apply(x):
+        if isinstance(x, ShardedTensorFactory):
+            x = x.build()
+        return x
+
+    dict_list_map_inplace(apply, sharded_state_dict)
+
+
+def apply_factory_merges(
+    x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
+) -> StateDict:
+    """Apply merges defined by ShardedTensorFactories *in-place*.
+
+    Args:
+        x1 (StateDict): state dict loaded from the checkpoint
+        x2 (ShardedStateDict): subset of `x1` (in terms of dict keys)
+            with ShardedTensorFactory
+            as (possibly nested) values that define how to
+            merge objects from the `x1` state dict
+        key (Tuple[str, ...]): current key in a recursive call.
+            Used only for reporting meaningful errors
+
+    Returns:
+        StateDict: `x1` modified in-place
+    """
+    if isinstance(x2, ShardedTensorFactory):
+        return x2.merge_fn(x1)
+
+    # The rest is almost the same as the `merge` function from `dict_utils`
+    if isinstance(x1, dict) and isinstance(x2, dict):
+        for k, v2 in x2.items():
+            if k not in x1:
+                raise ValueError(
+                    f'Different dict keys encountered in `apply_factory_merges` '
+                    f'({x1.keys()} vs {x2.keys()})'
+                )
+            else:
+                x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
+    elif isinstance(x1, list) and isinstance(x2, list):
+        if len(x1) != len(x2):
+            err_msg = (
+                f'Cannot merge two lists with different lengths '
+                f'({len(x1)} and {len(x2)}, encountered at key {key})'
+            )
+            logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
+            raise ValueError(err_msg)
+        for i, v2 in enumerate(x2):
+            x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
+    elif isinstance(x1, list) and isinstance(x2, dict):
+        for k, v2 in x2.items():
+            if not isinstance(k, int):
+                raise ValueError(
+                    f'Invalid dict key {k} non-integer type encountered '
+                    f'in a list-dict merge at level {key}'
+                )
+            if k >= len(x1):
+                raise ValueError(
+                    f'Dict key {k} out of bounds for list of length '
+                    f'{len(x1)} (encountered at level {key})'
+                )
+            x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
+    else:
+        raise ValueError(
+            f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at key {key})'
+        )
+    return x1
diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py
index 600dd87e5438620a9214d97f2c688f056e5c4aef..b8c7badf547836ad37d1d5907f1f554e7eb57d49 100644
--- a/megatron/core/dist_checkpointing/serialization.py
+++ b/megatron/core/dist_checkpointing/serialization.py
@@ -1,424 +1,424 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Entrypoints for saving and loading the distributed checkpoints.
-
-Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
-but expect torch.Tensors to be wrapped with classes from the `mapping module`.
-Additionally, `load` expects the sharded state dict argument as a guidance for
-loading the sharded tensors.
-"""
-
-import logging
-from pathlib import Path
-from typing import Callable, Dict, Optional, Set, Tuple, Union
-
-import torch
-
-from .
import ShardedTensor -from .core import CheckpointingConfig, save_config -from .dict_utils import extract_matching_values, merge -from .mapping import ( - CheckpointingException, - CommonStateDict, - ShardedObject, - ShardedStateDict, - StateDict, - apply_factory_merges, -) -from .state_dict_transformation import load_preprocess, save_preprocess -from .strategies.async_utils import AsyncRequest -from .strategies.base import ( - AsyncSaveShardedStrategy, - LoadCommonStrategy, - LoadShardedStrategy, - SaveCommonStrategy, - SaveShardedStrategy, - StrategyAction, - get_default_strategy, -) -from .utils import extract_sharded_base -from .validation import ( - StrictHandling, - determine_global_metadata, - parse_strict_flag, - validate_integrity_and_strict_load, - validate_sharded_objects_handling, - verify_checkpoint_and_load_strategy, -) - -logger = logging.getLogger(__name__) - - -# flat state dict with sharded objects without any data -CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]] - - -def load( - sharded_state_dict: ShardedStateDict, - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, - common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, - validate_access_integrity: bool = True, - strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED, -) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]: - """Loading entrypoint. - - In the steps below, the following verbs refer to corresponding objects: - - load = load from checkpoint - - extract = extract from sharded_state_dict - - add = add to the final state dict - Steps: - 1. Load common state dict and form the base of the result state dict - 2. Apply factories to sharded_state_dict - 3. Extract LocalNonPersistentObject and add - 4. (optional) Extract ShardedObjects, load and add - 5. Extract ShardedBase, load, apply factory merges and add - - Args: - sharded_state_dict (ShardedStateDict): state dict of the existing model - populated with ShardedTensors. Used as a mapping to determine which - parts of global tensors stored in the checkpoint should be loaded. - checkpoint_dir (str): directory with the checkpoint - sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): - configures loading behavior for sharded tensors - common_strategy (LoadCommonStrategy, Tuple[str, int], optional): - configures loading behavior for common data - validate_access_integrity (bool default = True): checks if each tensor shard is accessed - exactly once (as main replica) by some process - strict (StrictHandling, str, optional): determines the behavior in case of a mismatch - between the requested sharded state dict and the checkpoint. See `StrictHandling` docs - for more details. Some values affect the return value of this function - (missing and unexpected keys are returned). - Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't - incur any performance overhead. Other recommended values - are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys - or `StrictHandling.RETURN_ALL` which returns all mismatch keys. - - Returns: - StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only - the loaded state dict is returned. 
If `strict` flag was set to - """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy, common_strategy - ) - - checkpoint_dir = Path(checkpoint_dir) - common_state_dict = common_strategy.load_common(checkpoint_dir) - - sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( - sharded_state_dict - ) - merge(common_state_dict, nonpersistent_state_dict) - - # At this point we are only dealing with ShardedBase objects - sharded_state_dict, _ = extract_sharded_base(sharded_state_dict) - - # Validation - ckpt_sharded_metadata = None - local_metadata, global_metadata = None, None - strict = parse_strict_flag(strict) - if StrictHandling.requires_explicit_ckpt_mismatch_check(strict): - ckpt_sharded_metadata = load_sharded_metadata( - str(checkpoint_dir), sharded_strategy, common_strategy - ) - if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict): - local_metadata, global_metadata = determine_global_metadata(sharded_state_dict) - - sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load( - sharded_state_dict, - strict, - validate_access_integrity, - local_metadata, - global_metadata, - ckpt_sharded_metadata, - ) - - # ShardedBase loading - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - sharded_objects_state_dict, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) - sharded_objects = common_strategy.load_sharded_objects( - sharded_objects_state_dict, checkpoint_dir - ) - merge(common_state_dict, sharded_objects) - - loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) - - merge(common_state_dict, loaded_state_dict) - - loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories) - - if StrictHandling.requires_returning_mismatch_keys(strict): - return common_state_dict, missing_keys, unexpected_keys - else: - return common_state_dict - - -def load_common_state_dict(checkpoint_dir: Path) -> StateDict: - """Load common (non-sharded) objects state dict from the checkpoint. - - Args: - checkpoint_dir (Path): checkpoint directory - - Returns: - StateDict: state dict with non-sharded objects from the checkpoint - """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir)) - return common_strategy.load_common(checkpoint_dir) - - -def load_tensors_metadata( - checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None -) -> CkptShardedMetadata: - """Load tensors metadata from the checkpoint. - - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply ShardedTensor keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors without any sharding (so, the only useful - information is tensors global shape and dtype). - - Concrete implementation depends on the loading strategy. If no strategy is - given, a default for a given backend is used. - - Args: - checkpoint_dir (str): checkpoint directory to load from - sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. - Defaults to None - in this case a default load strategy for a given checkpoint type - is used. 
- - Returns: - CkptShardedMetadata: flat state dict without data describing ShardedTensors - in the checkpoint - """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy - ) - return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir)) - - -def load_sharded_metadata( - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, None] = None, - common_strategy: Union[LoadCommonStrategy, None] = None, -) -> CkptShardedMetadata: - """Load sharded metadata from the checkpoint. - - Similar to `load_tensors_metadata`, but includes also ShardedObjects. - - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply ShardedTensor keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors without any sharding (so, the only useful - information is tensors global shape and dtype). - - Concrete implementation depends on the loading strategy. If no strategy is - given, a default for a given backend is used. - - Args: - checkpoint_dir (str): checkpoint directory to load from - sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. - Defaults to None - in this case a default load strategy for a given checkpoint type - is used. - common_strategy (LoadCommonStrategy, optional): common strategy to load metadata. - Defaults to None - in this case a default load strategy for a given checkpoint type is - used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects - - Returns: - CkptShardedMetadata: flat state dict without data describing ShardedTensors - and ShardedObjects in the checkpoint - """ - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( - checkpoint_dir, sharded_strategy, common_strategy - ) - sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir)) - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir)) - sharded_metadata = merge(sharded_metadata, common_metadata) - return sharded_metadata - - -def load_plain_tensors(checkpoint_dir: str) -> StateDict: - """Load checkpoint tensors without any sharding and plain structure. - - NOTE: common state dict is NOT included. - - Args: - checkpoint_dir (str): checkpoint directory to load the tensors from. - - Returns: - StateDict: checkpoint state dict containing only torch.Tensors. - """ - sharded_state_dict = load_tensors_metadata(checkpoint_dir) - # Don't validate integrity because shards will be overlapped - # if world_size > 1 (all processes load whole tensors) - return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) - - -# -# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict: -# """Load checkpoint tensors and objects without any sharding and plain structure. -# -# NOTE: state dict structure might be different than the one used for checkpoint saving. -# NOTE: common state dict is NOT included. -# -# Args: -# checkpoint_dir (str): checkpoint directory to load the state dict from. -# -# Returns: -# StateDict: complete checkpoint state dict without any sharding. 
-# """ -# sharded_state_dict = load_tensors_metadata(checkpoint_dir) -# # Don't validate integrity because shards will be overlapped -# # if world_size > 1 (all processes load whole tensors) -# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) - - -def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str): - """determine the appropriate sharding strategy and delegate removal to the sharded strategy""" - sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir) - sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix) - - -def save( - sharded_state_dict: ShardedStateDict, - checkpoint_dir: str, - sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, - common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None, - validate_access_integrity: bool = True, - async_sharded_save: bool = False, - preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, -) -> Optional[AsyncRequest]: - """Saving entrypoint. - - Extracts ShardedTensors from the given state dict. Rank 0 saves the - "regular" part of the checkpoint to common torch file. - The ShardedTensors are saved according to a strategy specified by the - config. - - Steps: - 1. Apply factories - 2. Extract and discard LocalNonPersistentObject - 3. Extract all ShardedBase object - 4. Save all other objects to common.pt - 5. (optional) Extract and save ShardedObjects - 6. Save all ShardedBase objects - 7. Write metadata.json file with backend and version metadata. - - Step (6) can be performed asynchronously (see `async_sharded_save`), in this - case the actual save is embodied in the returned async request and can be - scheduled by the external caller. For async request, step (7) is added as - one of the finalization functions, so that metadata.json is written only - if the checkpoint is complete. - - Args: - sharded_state_dict (ShardedStateDict): state dict of the populated with - ShardedTensors. Used as a mapping to determine how local tensors - should be saved as global tensors in the checkpoint. - checkpoint_dir (str): directory to save the checkpoint to - sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): - configures sharded tensors saving behavior and backend - common_strategy (SaveCommonStrategy, Tuple[str, int], optional): - configures common data saving behavior and backend - validate_access_integrity (bool default = True): checks if each tensor shard is accessed - exactly once (as main replica) by some process. - It also makes sure the common state dict is consistant across all ranks - async_sharded_save (bool, optional): if True, for the sharded state dict part - an async save implementation will be called, with the AsyncRequest - being returned to the caller. Note that it is the caller responsibility to - actually schedule the async save. Defaults to False. - preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None): - A callable function that will preprocess the common state dict (i.e can be used to - remove keys that we expect to be different in the state dict). The function must not - modify the original state dict - - Returns: - AsyncRequest (optional): if `async_sharded_save` is True, returns - async request that should be scheduled by the caller of this function. - None otherwise. 
- """ - checkpoint_dir = Path(checkpoint_dir) - - if torch.distributed.get_rank() == 0: - if not checkpoint_dir.exists(): - raise CheckpointingException( - f'Checkpoint destination directory does not exist: {checkpoint_dir}' - ) - - if next(checkpoint_dir.iterdir(), None) is not None: - raise CheckpointingException( - f'Checkpoint destination directory ({checkpoint_dir}) is not empty' - ) - - if common_strategy is not None: - raise NotImplementedError('The only supported common strategy is torch') - - if sharded_strategy is None: - sharded_strategy = get_default_save_sharded_strategy() - if not isinstance(sharded_strategy, SaveShardedStrategy): - assert isinstance(sharded_strategy, tuple), type(sharded_strategy) - sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy) - - if common_strategy is None: - common_strategy = get_default_save_common_strategy() - if not isinstance(common_strategy, SaveCommonStrategy): - assert isinstance(common_strategy, tuple), type(common_strategy) - common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy) - - sharded_state_dict, state_dict = save_preprocess( - sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check - ) - - common_strategy.save_common(state_dict, checkpoint_dir) - - if not sharded_strategy.can_handle_sharded_objects: - validate_sharded_objects_handling(sharded_strategy, common_strategy) - sharded_objects_state_dict, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) - common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir) - - def metadata_finalize_fn(): - if torch.distributed.get_rank() == 0: - save_config( - CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), - checkpoint_dir, - ) - torch.distributed.barrier() - - if not async_sharded_save: - sharded_strategy.save(sharded_state_dict, checkpoint_dir) - metadata_finalize_fn() - return - - if not isinstance(sharded_strategy, AsyncSaveShardedStrategy): - raise CheckpointingException( - f'Cannot apply async_save to non-async strategy {sharded_strategy}' - ) - async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir) - async_request.finalize_fns.append(metadata_finalize_fn) - return async_request - - -def get_default_save_sharded_strategy( - backend: str = 'torch_dist', version: int = 1 -) -> SaveShardedStrategy: - """Get default save sharded strategy.""" - return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version) - - -def get_default_save_common_strategy( - backend: str = 'torch', version: int = 1 -) -> SaveCommonStrategy: - """Get default save common strategy.""" - return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version) - - -def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy: - """Get default load sharded strategy.""" - return verify_checkpoint_and_load_strategy(checkpoint_dir)[0] +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Entrypoints for saving and loading the distributed checkpoints. + +Functions `load` and `save` are equivalents of `torch.load` and `torch.save` +but expect torch.Tensors to be wrapped with classes from the `mapping module`. +Additionally, `load` expects the sharded state dict argument as a guidance for +loading the sharded tensors. 
+""" + +import logging +from pathlib import Path +from typing import Callable, Dict, Optional, Set, Tuple, Union + +import torch + +from . import ShardedTensor +from .core import CheckpointingConfig, save_config +from .dict_utils import extract_matching_values, merge +from .mapping import ( + CheckpointingException, + CommonStateDict, + ShardedObject, + ShardedStateDict, + StateDict, + apply_factory_merges, +) +from .state_dict_utils import load_preprocess, save_preprocess +from .strategies.async_utils import AsyncRequest +from .strategies.base import ( + AsyncSaveShardedStrategy, + LoadCommonStrategy, + LoadShardedStrategy, + SaveCommonStrategy, + SaveShardedStrategy, + StrategyAction, + get_default_strategy, +) +from .utils import extract_sharded_base +from .validation import ( + StrictHandling, + determine_global_metadata, + parse_strict_flag, + validate_integrity_and_strict_load, + validate_sharded_objects_handling, + verify_checkpoint_and_load_strategy, +) + +logger = logging.getLogger(__name__) + + +# flat state dict with sharded objects without any data +CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]] + + +def load( + sharded_state_dict: ShardedStateDict, + checkpoint_dir: str, + sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, + common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, + validate_access_integrity: bool = True, + strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED, +) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]: + """Loading entrypoint. + + In the steps below, the following verbs refer to corresponding objects: + - load = load from checkpoint + - extract = extract from sharded_state_dict + - add = add to the final state dict + Steps: + 1. Load common state dict and form the base of the result state dict + 2. Apply factories to sharded_state_dict + 3. Extract LocalNonPersistentObject and add + 4. (optional) Extract ShardedObjects, load and add + 5. Extract ShardedBase, load, apply factory merges and add + + Args: + sharded_state_dict (ShardedStateDict): state dict of the existing model + populated with ShardedTensors. Used as a mapping to determine which + parts of global tensors stored in the checkpoint should be loaded. + checkpoint_dir (str): directory with the checkpoint + sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): + configures loading behavior for sharded tensors + common_strategy (LoadCommonStrategy, Tuple[str, int], optional): + configures loading behavior for common data + validate_access_integrity (bool default = True): checks if each tensor shard is accessed + exactly once (as main replica) by some process + strict (StrictHandling, str, optional): determines the behavior in case of a mismatch + between the requested sharded state dict and the checkpoint. See `StrictHandling` docs + for more details. Some values affect the return value of this function + (missing and unexpected keys are returned). + Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't + incur any performance overhead. Other recommended values + are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys + or `StrictHandling.RETURN_ALL` which returns all mismatch keys. + + Returns: + StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only + the loaded state dict is returned. 
If `strict` flag was set to + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy, common_strategy + ) + + checkpoint_dir = Path(checkpoint_dir) + common_state_dict = common_strategy.load_common(checkpoint_dir) + + sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( + sharded_state_dict + ) + merge(common_state_dict, nonpersistent_state_dict) + + # At this point we are only dealing with ShardedBase objects + sharded_state_dict, _ = extract_sharded_base(sharded_state_dict) + + # Validation + ckpt_sharded_metadata = None + local_metadata, global_metadata = None, None + strict = parse_strict_flag(strict) + if StrictHandling.requires_explicit_ckpt_mismatch_check(strict): + ckpt_sharded_metadata = load_sharded_metadata( + str(checkpoint_dir), sharded_strategy, common_strategy + ) + if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict): + local_metadata, global_metadata = determine_global_metadata(sharded_state_dict) + + sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load( + sharded_state_dict, + strict, + validate_access_integrity, + local_metadata, + global_metadata, + ckpt_sharded_metadata, + ) + + # ShardedBase loading + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + sharded_objects_state_dict, sharded_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, ShardedObject) + ) + sharded_objects = common_strategy.load_sharded_objects( + sharded_objects_state_dict, checkpoint_dir + ) + merge(common_state_dict, sharded_objects) + + loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) + + merge(common_state_dict, loaded_state_dict) + + loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories) + + if StrictHandling.requires_returning_mismatch_keys(strict): + return common_state_dict, missing_keys, unexpected_keys + else: + return common_state_dict + + +def load_common_state_dict(checkpoint_dir: Path) -> StateDict: + """Load common (non-sharded) objects state dict from the checkpoint. + + Args: + checkpoint_dir (Path): checkpoint directory + + Returns: + StateDict: state dict with non-sharded objects from the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir)) + return common_strategy.load_common(checkpoint_dir) + + +def load_tensors_metadata( + checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None +) -> CkptShardedMetadata: + """Load tensors metadata from the checkpoint. + + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply ShardedTensor keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors without any sharding (so, the only useful + information is tensors global shape and dtype). + + Concrete implementation depends on the loading strategy. If no strategy is + given, a default for a given backend is used. + + Args: + checkpoint_dir (str): checkpoint directory to load from + sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type + is used. 
+ + Returns: + CkptShardedMetadata: flat state dict without data describing ShardedTensors + in the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy + ) + return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir)) + + +def load_sharded_metadata( + checkpoint_dir: str, + sharded_strategy: Union[LoadShardedStrategy, None] = None, + common_strategy: Union[LoadCommonStrategy, None] = None, +) -> CkptShardedMetadata: + """Load sharded metadata from the checkpoint. + + Similar to `load_tensors_metadata`, but includes also ShardedObjects. + + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply ShardedTensor keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors without any sharding (so, the only useful + information is tensors global shape and dtype). + + Concrete implementation depends on the loading strategy. If no strategy is + given, a default for a given backend is used. + + Args: + checkpoint_dir (str): checkpoint directory to load from + sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type + is used. + common_strategy (LoadCommonStrategy, optional): common strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type is + used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects + + Returns: + CkptShardedMetadata: flat state dict without data describing ShardedTensors + and ShardedObjects in the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy, common_strategy + ) + sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir)) + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir)) + sharded_metadata = merge(sharded_metadata, common_metadata) + return sharded_metadata + + +def load_plain_tensors(checkpoint_dir: str) -> StateDict: + """Load checkpoint tensors without any sharding and plain structure. + + NOTE: common state dict is NOT included. + + Args: + checkpoint_dir (str): checkpoint directory to load the tensors from. + + Returns: + StateDict: checkpoint state dict containing only torch.Tensors. + """ + sharded_state_dict = load_tensors_metadata(checkpoint_dir) + # Don't validate integrity because shards will be overlapped + # if world_size > 1 (all processes load whole tensors) + return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) + + +# +# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict: +# """Load checkpoint tensors and objects without any sharding and plain structure. +# +# NOTE: state dict structure might be different than the one used for checkpoint saving. +# NOTE: common state dict is NOT included. +# +# Args: +# checkpoint_dir (str): checkpoint directory to load the state dict from. +# +# Returns: +# StateDict: complete checkpoint state dict without any sharding. 
+# """ +# sharded_state_dict = load_tensors_metadata(checkpoint_dir) +# # Don't validate integrity because shards will be overlapped +# # if world_size > 1 (all processes load whole tensors) +# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) + + +def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str): + """determine the appropriate sharding strategy and delegate removal to the sharded strategy""" + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir) + sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix) + + +def save( + sharded_state_dict: ShardedStateDict, + checkpoint_dir: str, + sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, + common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None, + validate_access_integrity: bool = True, + async_sharded_save: bool = False, + preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, +) -> Optional[AsyncRequest]: + """Saving entrypoint. + + Extracts ShardedTensors from the given state dict. Rank 0 saves the + "regular" part of the checkpoint to common torch file. + The ShardedTensors are saved according to a strategy specified by the + config. + + Steps: + 1. Apply factories + 2. Extract and discard LocalNonPersistentObject + 3. Extract all ShardedBase object + 4. Save all other objects to common.pt + 5. (optional) Extract and save ShardedObjects + 6. Save all ShardedBase objects + 7. Write metadata.json file with backend and version metadata. + + Step (6) can be performed asynchronously (see `async_sharded_save`), in this + case the actual save is embodied in the returned async request and can be + scheduled by the external caller. For async request, step (7) is added as + one of the finalization functions, so that metadata.json is written only + if the checkpoint is complete. + + Args: + sharded_state_dict (ShardedStateDict): state dict of the populated with + ShardedTensors. Used as a mapping to determine how local tensors + should be saved as global tensors in the checkpoint. + checkpoint_dir (str): directory to save the checkpoint to + sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): + configures sharded tensors saving behavior and backend + common_strategy (SaveCommonStrategy, Tuple[str, int], optional): + configures common data saving behavior and backend + validate_access_integrity (bool default = True): checks if each tensor shard is accessed + exactly once (as main replica) by some process. + It also makes sure the common state dict is consistant across all ranks + async_sharded_save (bool, optional): if True, for the sharded state dict part + an async save implementation will be called, with the AsyncRequest + being returned to the caller. Note that it is the caller responsibility to + actually schedule the async save. Defaults to False. + preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None): + A callable function that will preprocess the common state dict (i.e can be used to + remove keys that we expect to be different in the state dict). The function must not + modify the original state dict + + Returns: + AsyncRequest (optional): if `async_sharded_save` is True, returns + async request that should be scheduled by the caller of this function. + None otherwise. 
+ """ + checkpoint_dir = Path(checkpoint_dir) + + if torch.distributed.get_rank() == 0: + if not checkpoint_dir.exists(): + raise CheckpointingException( + f'Checkpoint destination directory does not exist: {checkpoint_dir}' + ) + + if next(checkpoint_dir.iterdir(), None) is not None: + raise CheckpointingException( + f'Checkpoint destination directory ({checkpoint_dir}) is not empty' + ) + + if common_strategy is not None: + raise NotImplementedError('The only supported common strategy is torch') + + if sharded_strategy is None: + sharded_strategy = get_default_save_sharded_strategy() + if not isinstance(sharded_strategy, SaveShardedStrategy): + assert isinstance(sharded_strategy, tuple), type(sharded_strategy) + sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy) + + if common_strategy is None: + common_strategy = get_default_save_common_strategy() + if not isinstance(common_strategy, SaveCommonStrategy): + assert isinstance(common_strategy, tuple), type(common_strategy) + common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy) + + sharded_state_dict, state_dict = save_preprocess( + sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check + ) + + common_strategy.save_common(state_dict, checkpoint_dir) + + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + sharded_objects_state_dict, sharded_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, ShardedObject) + ) + common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir) + + def metadata_finalize_fn(): + if torch.distributed.get_rank() == 0: + save_config( + CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), + checkpoint_dir, + ) + torch.distributed.barrier() + + if not async_sharded_save: + sharded_strategy.save(sharded_state_dict, checkpoint_dir) + metadata_finalize_fn() + return + + if not isinstance(sharded_strategy, AsyncSaveShardedStrategy): + raise CheckpointingException( + f'Cannot apply async_save to non-async strategy {sharded_strategy}' + ) + async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir) + async_request.finalize_fns.append(metadata_finalize_fn) + return async_request + + +def get_default_save_sharded_strategy( + backend: str = 'torch_dist', version: int = 1 +) -> SaveShardedStrategy: + """Get default save sharded strategy.""" + return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version) + + +def get_default_save_common_strategy( + backend: str = 'torch', version: int = 1 +) -> SaveCommonStrategy: + """Get default save common strategy.""" + return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version) + + +def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy: + """Get default load sharded strategy.""" + return verify_checkpoint_and_load_strategy(checkpoint_dir)[0] diff --git a/megatron/core/dist_checkpointing/state_dict_transformation.py b/megatron/core/dist_checkpointing/state_dict_transformation.py deleted file mode 100644 index c8f01dd4a2c96ac809eee627bf83c1a78cefa9d6..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/state_dict_transformation.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
- -""" Utilities for transforming state_dict, including a tensor-aware implementation.""" - -import logging -from time import time -from typing import Any, Callable, Optional - -import torch - -from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values -from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution -from .mapping import ( - CommonStateDict, - ShardedObject, - ShardedStateDict, - ShardedTensor, - ShardedTensorFactory, - StateDict, - apply_factories, - apply_factory_merges, -) -from .utils import ( - _sharded_object_id, - _sharded_tensor_shard_id, - extract_nonpersistent, - extract_sharded_base, -) -from .validation import determine_global_metadata, validate_sharding_integrity - -logger = logging.getLogger(__name__) - - -def save_preprocess( - sharded_state_dict: ShardedStateDict, - validate_access_integrity: bool = True, - preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, -): - """Preprocesses the given state dictionary by applying factories, - discarding non-persistent data and extracting the common state dictionary. - Optionally, it can validate sharding integrity. - - Args: - sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed. - validate_access_integrity (bool): If True, triggers validation of sharding integrity. - preprocess_common_before_consistancy_check (callable, None): A callable function - that will preprocess the common state dict (i.e can be used to remove keys - that we expect to be different in the state dict) - - Returns: - Tuple[ShardedStateDict, dict]: - The preprocessed sharded state dictionary and the common state dictionary. - """ - apply_factories(sharded_state_dict) - _, sharded_state_dict = extract_nonpersistent(sharded_state_dict) - sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict) - if validate_access_integrity: - preprocessed_common_state_dict = common_state_dict - if preprocess_common_before_consistancy_check: - preprocessed_common_state_dict = preprocess_common_before_consistancy_check( - common_state_dict - ) - validate_sharding_integrity( - determine_global_metadata(sharded_part)[1], - common_state_dict=preprocessed_common_state_dict, - ) - return sharded_part, common_state_dict - - -def load_preprocess(sharded_state_dict: ShardedStateDict): - """Preprocesses the given state dictionary by applying factories - and extracting non-persistent data, without modifying the original dictionary. - - Args: - sharded_state_dict (ShardedStateDict): - The initial state dictionary to be processed (remains unchanged). - - Returns: - Tuple[ShardedStateDict, dict, dict]: - - A preprocessed copy of the sharded state dictionary. - - A dictionary containing non-persistent state data. - - A dictionary of `ShardedTensorFactory` instances. 
- """ - # Create a copy of sharded_state_dict as the passed in state dict may have - # references that prevent tensors from being deallocated - sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True) - - sh_ten_factories, _ = extract_matching_values( - sharded_state_dict, - lambda x: isinstance(x, ShardedTensorFactory), - return_lists_as_dicts=True, - ) - apply_factories(sharded_state_dict) - - # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage - dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories) - # Non-persistent objects - nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict) - dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) - return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories - - -def prepare_state_dict_for_save( - sharded_state_dict: ShardedStateDict, - async_prepare: bool = False, - algo: str = 'atomic', - validate_access_integrity: bool = True, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, - to_cpu: bool = True, -): - """Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager. - - Args: - sharded_state_dict (ShardedStateDict): The initial state dictionary. - async_prepare (bool): If True, enables asynchronous preparation. - algo (str): The algorithm used to create the tensor-aware state dictionary. - validate_access_integrity (bool): If True, validates sharding integrity. - parallelization_group (torch.distributed.ProcessGroup): - The process group used for exchanges to avoid duplications. - to_cpu (bool): If True, moves all tensors from device to CPU. - - Returns: - ShardedStateDict: The tensor-aware state dictionary. - """ - - _start = time() - - if async_prepare: - raise NotImplementedError('Async state_dict preparation is not yet implemented') - if algo != 'atomic' and algo != 'fully_parallel': - raise NotImplementedError( - 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' 
- ) - fully_parallel = algo == 'fully_parallel' - - sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity) - sharded_tensors = [] - sharded_objects = [] - for sh_base in nested_values(sharded_part): - if isinstance(sh_base, ShardedTensor): - sharded_tensors.append(sh_base) - else: - assert isinstance(sh_base, ShardedObject) - sharded_objects.append(sh_base) - if fully_parallel: - shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution( - sharded_part, parallelization_group, True - ) - - raw_tensors, raw_objects = {}, {} - for ten in sharded_tensors: - shard_id = _sharded_tensor_shard_id(ten) - if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank(): - # TODO cover creating copies on host in CheckpointManager.save() - if to_cpu: - raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True) - else: - raw_tensors[shard_id] = ten.data - ten.data = None - for obj in sharded_objects: - raw_objects[_sharded_object_id(obj)] = obj.data - obj.data = None - - logger.debug(f'prepare_state_dict_for_save took {time() - _start}') - - state_dict_for_save = { - 'raw_tensors': raw_tensors, - 'raw_objects': raw_objects, - 'common': common_state_dict, - 'sharded_state_dict': sharded_part, - } - if fully_parallel: - state_dict_for_save['shard_to_rank'] = shard_to_saving_rank - state_dict_for_save['shard_to_metadata'] = shard_to_metadata - return state_dict_for_save - - -def recreate_state_dict_after_load( - sharded_state_dict: ShardedStateDict, - loaded_state_dict: ShardedStateDict, - algo: str = 'atomic', - exchange_algo: str = 'broadcast', - validate_access_integrity: bool = True, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Creates a final sharded state dictionary from a tensor-aware state dictionary. - - Args: - sharded_state_dict (ShardedStateDict): - The initial sharded state dictionary generated from the model. - loaded_state_dict (ShardedStateDict): - Tensor-aware state dictionary used to fill in missing data in the sharded state. - algo (str): The algorithm used to reconstruct the state dictionary - from the tensor-aware state dictionary. - exchange_algo (str): The algorithm used for tensor exchanges during retrieval. - validate_access_integrity (bool): If True, performs validation of sharding integrity. - parallelization_group (torch.distributed.ProcessGroup): - The process group used for efficient exchanges during retrieval. - - Returns: - ShardedStateDict: The finalized sharded state dictionary. - """ - - if algo != 'atomic' and algo != 'fully_parallel': - raise NotImplementedError( - 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' 
- ) - fully_parallel = algo == 'fully_parallel' - - # __adding__ common part - recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True) - - if not sharded_state_dict: - return recreated_state_dict - # TODO validate laoded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible - - sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( - sharded_state_dict - ) - # __adding__ nonpersistent part - merge(recreated_state_dict, nonpersistent_state_dict) - - sharded_part, _ = extract_sharded_base(sharded_state_dict) - if validate_access_integrity: - validate_sharding_integrity(determine_global_metadata(sharded_part)[1]) - - # load sharded tensors and sharded objects to sharded_part - loaded_tensors = loaded_state_dict['raw_tensors'] - # TODO cover restoring the original device (H2D) in CheckpointManager.load() - for k, v in loaded_tensors.items(): - loaded_tensors[k] = v.cuda() # H2D - if fully_parallel: - distribution = ( - loaded_state_dict['shard_to_rank'], - None, - loaded_state_dict['shard_to_metadata'], - ) - unloaded_shards = {} - for sh_base in nested_values(sharded_part): - if isinstance(sh_base, ShardedTensor): - shard_id = _sharded_tensor_shard_id(sh_base) - if shard_id not in loaded_tensors: - unloaded_shards[shard_id] = sh_base - loaded_tensors = exchange_by_distribution( - loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo - ) - loaded_objects = loaded_state_dict['raw_objects'] - - def load_sharded_base(x: Any): - if isinstance(x, ShardedTensor): - shard_id = _sharded_tensor_shard_id(x) - if shard_id not in loaded_tensors: - raise Exception( - 'The current local checkpoint implementation assumes' - 'consistent tensor sharding during load and save operations.' - f'However, the expected shard {x} (ID: {shard_id})' - f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})' - ) - x = loaded_tensors[shard_id] - if isinstance(x, ShardedObject): - object_id = _sharded_object_id(x) - assert object_id in loaded_objects, (x, object_id, loaded_objects.keys()) - x = loaded_objects[object_id] - return x - - dict_list_map_inplace(load_sharded_base, sharded_part) - sharded_part = apply_factory_merges(sharded_part, sh_ten_factories) - # __adding__ sharded_part - merge(recreated_state_dict, sharded_part) - return recreated_state_dict diff --git a/megatron/core/dist_checkpointing/state_dict_utils.py b/megatron/core/dist_checkpointing/state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74de4fd6667c618d30c2aee37779ab40567b3e62 --- /dev/null +++ b/megatron/core/dist_checkpointing/state_dict_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
+ +""" Utilities for transforming state_dict.""" + +from typing import Callable, Union + +from .dict_utils import dict_list_map_inplace, extract_matching_values +from .mapping import ( + CommonStateDict, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, + apply_factories, +) +from .utils import extract_nonpersistent, extract_sharded_base +from .validation import determine_global_metadata, validate_sharding_integrity + + +def save_preprocess( + sharded_state_dict: ShardedStateDict, + validate_access_integrity: bool = True, + preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, +): + """Preprocesses the given state dictionary by applying factories, + discarding non-persistent data and extracting the common state dictionary. + Optionally, it can validate sharding integrity. + + Args: + sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed. + validate_access_integrity (bool): If True, triggers validation of sharding integrity. + preprocess_common_before_consistancy_check (callable, None): A callable function + that will preprocess the common state dict (i.e can be used to remove keys + that we expect to be different in the state dict) + + Returns: + Tuple[ShardedStateDict, dict]: + The preprocessed sharded state dictionary and the common state dictionary. + """ + apply_factories(sharded_state_dict) + _, sharded_state_dict = extract_nonpersistent(sharded_state_dict) + sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict) + sharded_part = filter_out_empty_flatten_tensor(sharded_part) + if validate_access_integrity: + preprocessed_common_state_dict = common_state_dict + if preprocess_common_before_consistancy_check: + preprocessed_common_state_dict = preprocess_common_before_consistancy_check( + common_state_dict + ) + validate_sharding_integrity( + determine_global_metadata(sharded_part)[1], + common_state_dict=preprocessed_common_state_dict, + ) + return sharded_part, common_state_dict + + +def load_preprocess(sharded_state_dict: ShardedStateDict): + """Preprocesses the given state dictionary by applying factories + and extracting non-persistent data, without modifying the original dictionary. + + Args: + sharded_state_dict (ShardedStateDict): + The initial state dictionary to be processed (remains unchanged). + + Returns: + Tuple[ShardedStateDict, dict, dict]: + - A preprocessed copy of the sharded state dictionary. + - A dictionary containing non-persistent state data. + - A dictionary of `ShardedTensorFactory` instances. 
+ """ + # Create a copy of sharded_state_dict as the passed in state dict may have + # references that prevent tensors from being deallocated + sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True) + sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict) + + sh_ten_factories, _ = extract_matching_values( + sharded_state_dict, + lambda x: isinstance(x, ShardedTensorFactory), + return_lists_as_dicts=True, + ) + apply_factories(sharded_state_dict) + + # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage + dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories) + # Non-persistent objects + nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict) + dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) + return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories + + +def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]): + """ + Filter out ShardedTensors with empty flatten_range. + These tensors can cause the PyTorch check in failure. + + Args: + sharded_state_dict: state dict possibly containing ShardedTensor objects + """ + # Filter out ShardedTensors with empty flatten_range. + # These tensors can cause the PyTorch check in + # `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail. + # This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases. + sharded_state_dict, _ = extract_matching_values( + sharded_state_dict, + lambda v: not ( + isinstance(v, ShardedTensor) + and v.flattened_range + and v.flattened_range.start == v.flattened_range.stop + ), + ) + + return sharded_state_dict diff --git a/megatron/core/dist_checkpointing/strategies/async_utils.py b/megatron/core/dist_checkpointing/strategies/async_utils.py index 7cdda8ac329f302e7ed8fbf404ee397f145fd024..29b2d5fb97e9c8c9bb93b14ea7bd5542df5a286e 100644 --- a/megatron/core/dist_checkpointing/strategies/async_utils.py +++ b/megatron/core/dist_checkpointing/strategies/async_utils.py @@ -1,224 +1,543 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -""" -This module provides an async utilities which allow to start -a checkpoint save process in the background. -""" -import logging -from collections import deque -from time import time -from typing import Callable, List, NamedTuple, Optional, Tuple - -import torch -from torch import multiprocessing as mp - -logger = logging.getLogger(__name__) - - -class AsyncRequest(NamedTuple): - """Represents an async request that needs to be scheduled for execution. - - Args: - async_fn (Callable, optional): async function to call. None represents noop. - async_fn_args (Tuple): args to pass to `async_fn`. - finalize_fns (List[Callable]): list of functions to call to finalize the request. - These functions will be called synchronously after `async_fn` is done - *on all ranks*. - """ - - async_fn: Optional[Callable] - async_fn_args: Tuple - finalize_fns: List[Callable] - is_frozen: bool = False - - def add_finalize_fn(self, fn: Callable) -> None: - """Adds a new finalize function to the request. - - Args: - fn (Callable): function to add to the async request. This function - will be called *after* existing finalization functions. - - Returns: - None - """ - if self.is_frozen: - raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest') - self.finalize_fns.append(fn) - - def execute_sync(self) -> None: - """Helper to synchronously execute the request. 
- - This logic is equivalent to what should happen in case of the async call. - """ - if self.async_fn is not None: - self.async_fn(*self.async_fn_args) - torch.distributed.barrier() - for finalize_fn in self.finalize_fns: - finalize_fn() - - def freeze(self) -> 'AsyncRequest': - """Freezes the async request, disallowing adding new finalization functions. - - Returns: - AsyncRequest: new async request with all same fields except for the - `is_frozen` flag. - """ - return self._replace(is_frozen=True) - - -class DistributedAsyncCaller: - """Wrapper around mp.Process that ensures correct semantic of distributed finalization. - - Starts process asynchronously and allows checking if all processes on all ranks are done. - """ - - def __init__(self): - self.process: Optional[mp.Process] = None - self.start_time: Optional[float] = None - - def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None: - """Spawn a process with `async_fn` as the target. - - This method must be called on all ranks. - - Args: - async_fn (Callable, optional): async function to call. If None, - no process will be started. - save_args (Tuple): async function args. - """ - if async_fn is None: - return # nothing to do - start_sync = time() - torch.cuda.synchronize() - end_sync = time() - logger.debug( - f"rank: {torch.distributed.get_rank()}, takes {end_sync - start_sync} to finish D2H " - ) - - ctx = mp.get_context('fork') - self.start_time = time() - self.process = ctx.Process(target=async_fn, args=save_args) - self.process.start() - init_time = time() - logger.debug( - f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} to schedule async ckpt " - ) - - def is_current_async_call_done(self, blocking=False) -> bool: - """Check if async save is finished on all ranks. - - For semantic correctness, requires rank synchronization in each check. - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until the call is done - on all ranks. Otherwise, returns immediately if at least one rank - is still active. Defaults to False. - - Returns: - bool: True if all ranks are done (immediately of after active wait - if `blocking` is True), False if at least one rank is still active. - """ - # The following takes the same overhead as torch.distributed.barrier (single integer all-reduce) - is_alive = int(self.process.is_alive()) if self.process is not None else 0 - ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) - logger.debug( - f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}" - ) - torch.distributed.all_reduce(ten) - if ten[0] > 0 and not blocking: - return False - else: - if self.process is not None: - logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") - self.process.join() - self.process = None - - logger.debug( - f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking" - ) - self.start_time = None - return True - - -class _ActiveAsyncRequest(NamedTuple): - """Helper to represent an active async call. 
- - Args: - idx (int): index of the call (starting from 0) - async_caller (DistributedAsyncCaller): async caller instance that represents - the async process handling the async request - async_request (AsyncRequest): async request that is being called - """ - - idx: int - async_caller: DistributedAsyncCaller - async_request: AsyncRequest - - -class AsyncCallsQueue: - """Manages a queue of async calls. - - Allows adding a new async call with `schedule_async_request` and finalizing - active calls with `maybe_finalize_async_calls`. - """ - - def __init__(self): - self.async_calls: deque[_ActiveAsyncRequest] = deque([]) - self.call_idx: int = -1 - - def schedule_async_request(self, async_request: AsyncRequest) -> int: - """Start a new async call and add it to a queue of active async calls. - - This method must be called on all ranks. - - Args: - async_request (AsyncRequest): async request to start. - - Returns: - int: index of the async call that was started. - This can help the user keep track of the async calls. - """ - self.call_idx += 1 - async_caller = DistributedAsyncCaller() - async_request = async_request.freeze() - async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args) - self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request)) - return self.call_idx - - def maybe_finalize_async_calls(self, blocking=False) -> List[int]: - """Finalizes all available calls. - - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until all active requests - are done. Otherwise, finalizes only the async request that already - finished. Defaults to False. - Returns: - List[int]: list of indices (as returned by `schedule_async_request`) - of async calls that have been successfully finalized. - """ - call_idx_finalized = [] - while self.async_calls: - next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking) - if not next_async_done: - break - call_idx, _, async_request = self.async_calls.popleft() - for finalize_fn in async_request.finalize_fns: - finalize_fn() - ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device()) - torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX) - assert ( - ten.item() == call_idx - ), 'Unmatched async calls. That probably means not all ranks are participating in async finalization' - call_idx_finalized.append(call_idx) - return call_idx_finalized - - def get_num_unfinalized_calls(self): - """Get the number of active async calls.""" - return len(self.async_calls) - - def close(self): - """Finalize all calls upon closing.""" - self.maybe_finalize_async_calls(blocking=True) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +This module provides an async utilities which allow to start +a checkpoint save process in the background. 
+""" +import gc +import logging +from abc import ABC, abstractmethod +from collections import deque +from contextlib import contextmanager +from queue import Empty +from time import sleep, time +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import torch +from torch import multiprocessing as mp + +from ..utils import debug_time + +logger = logging.getLogger(__name__) + + +@contextmanager +def _disable_gc(): + """Temporarily disables GC.""" + gc_enabled = gc.isenabled() + try: + if gc_enabled: + gc.disable() + yield + finally: + if gc_enabled: + gc.enable() + + +class AsyncRequest(NamedTuple): + """Represents an async request that needs to be scheduled for execution. + + Args: + async_fn (Callable, optional): async function to call. None represents noop. + async_fn_args (Tuple): args to pass to `async_fn`. + finalize_fns (List[Callable]): list of functions to call to finalize the request. + These functions will be called synchronously after `async_fn` is done + *on all ranks*. + async_fn_kwargs (Tuple): kwargs to pass to `async_fn`. + preload_fn (Callable): preload function to stage tensors from GPU to Host. + This should be self-contained with a proper list of arguments with `partial`. + is_frozen (Bool): a flag to indicate this async request can be modified or not. + call_idx (int): index variable used to order async requests for synchronization + in preloading and writing tensors on the async caller + + """ + + async_fn: Optional[Callable] + async_fn_args: Tuple + finalize_fns: List[Callable] + async_fn_kwargs: Dict = {} + preload_fn: Callable = None + is_frozen: bool = False + call_idx: int = 0 + + def add_finalize_fn(self, fn: Callable) -> None: + """Adds a new finalize function to the request. + + Args: + fn (Callable): function to add to the async request. This function + will be called *after* existing finalization functions. + + Returns: + None + """ + if self.is_frozen: + raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest') + self.finalize_fns.append(fn) + + def execute_sync(self) -> None: + """Helper to synchronously execute the request. + + This logic is equivalent to what should happen in case of the async call. + """ + if self.async_fn is not None: + self.async_fn(*self.async_fn_args) + torch.distributed.barrier() + for finalize_fn in self.finalize_fns: + finalize_fn() + + def freeze(self) -> 'AsyncRequest': + """Freezes the async request, disallowing adding new finalization functions. + + Returns: + AsyncRequest: new async request with all same fields except for the + `is_frozen` flag. + """ + return self._replace(is_frozen=True) + + +class AsyncCaller(ABC): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + @abstractmethod + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Schedule `async_req` with some process forking or reusing + persistent worker + + This method must be called on all ranks. + + Args: + async_req (AsyncRequest): `AsyncRequest` object containing to + start async process + """ + raise NotImplementedError("This should be implemented") + + @abstractmethod + def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. 
+ + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + + """ + raise NotImplementedError("This should be implemented") + + def sync_all_async_calls(self, is_alive: int) -> bool: + """Check if all ranks have completed async checkpoint writing + + Args: + is_alive (bool): if True, the current async request is not completed + + Returns: + bool: True if all ranks are done, False if at least one rank is still active. + + """ + ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(ten) + return ten[0] == 0 + + @abstractmethod + def close(self): + """Terminate the async caller at exit of an application or some termination conditions""" + logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller") + + def __del__(self): + self.close() + + +class TemporalAsyncCaller(AsyncCaller): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + def __init__(self): + self.process: Optional[mp.Process] = None + self.start_time: Optional[float] = None + + @_disable_gc() + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Spawn a process with `async_fn` as the target. + + This method must be called on all ranks. + + Args: + async_fn (Callable, optional): async function to call. If None, + no process will be started. + async_req (AsyncRequest): `AsyncRequest` object containing to + start async process + """ + if async_req.async_fn is None: + return # nothing to do + + async_fn_args = list(async_req.async_fn_args) + if async_req.preload_fn: + # If there's a preload_fn in `async_req`, we call this func + # to do the defined action in `async_req.preload_fn` to + # stage GPU tensors to its defined destination + async_fn_args[1] = async_req.preload_fn() + + rank = torch.distributed.get_rank() + start_sync = time() + torch.cuda.synchronize() + end_sync = time() + logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ") + + ctx = mp.get_context('fork') + self.start_time = time() + self.process = ctx.Process( + target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs + ) + self.process.start() + init_time = time() + logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ") + + def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. 
+ """ + # The following takes the same overhead + # as torch.distributed.barrier (single integer all-reduce) + is_alive = int(self.process.is_alive()) if self.process is not None else 0 + is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive) + + if not is_done and blocking: + self.close() + is_done = True + return is_done + + def close(self): + if self.process: + logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") + self.process.join() + self.process = None + logger.debug( + "TemporalAsyncCaller: Async process join finished " + f"after {time() - self.start_time:.2f}s from forking" + ) + self.start_time = None + + +class PersistentAsyncCaller(AsyncCaller): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + def __init__(self): + self.process: mp.Process = None + self.start_time: Optional[float] = None + ctx = mp.get_context('spawn') + # main queue to deliver `AsyncRequest` from host to the ckpt worker + self.queue: mp.JoinableQueue = ctx.JoinableQueue() + # Queue used to synchronize for the completion of preloading tensors to host + # between a trainer and ckpt worker + self.preload_q: mp.JoinableQueue = ctx.JoinableQueue() + # Queue used to inform trainer when the saving is completed + self.comp_q: mp.Queue = ctx.Queue() + self.cur_item: int = None + self.cur_idx: int = -1 + + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Put `AsyncRequest` to the Persistent Async Caller + + This method must be called on all ranks. + + Args: + async_fn (Callable, optional): async function to call. If None, + no process will be started. + async_req (AsyncRequest): `AsyncRequest` object containing to + schedule a checkpointing request + """ + if async_req.async_fn is None: + return # nothing to do + + start_sync = end_sync = None + + self.start_time = time() + if self.process is None: + ctx = mp.get_context('spawn') + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller" + ) + self.process: mp.Process = ctx.Process( + target=PersistentAsyncCaller.async_loop, + args=( + torch.distributed.get_rank(), + self.queue, + self.preload_q, + self.comp_q, + logger.getEffectiveLevel(), + ), + ) + self.process.start() + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller" + ) + + if async_req.preload_fn: + self.preload_q.put(async_req.call_idx) + self.queue.put(async_req) + logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx}") + + if async_req.preload_fn: + start_sync = time() + # Synchronize for pre-staging tensors + self.preload_q.join() + end_sync = time() + logger.debug( + f"rank: {torch.distributed.get_rank()}, " + f"takes {end_sync - start_sync} to finish D2H " + ) + + init_time = time() + logger.debug( + f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} " + "to schedule async ckpt " + ) + + def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. 
+ no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + """ + + is_alive: bool = False + + if self.process: + while self.cur_item is None: + try: + # Retrieve comp call_idx without waiting + self.cur_item = self.comp_q.get_nowait() + except Empty: + # This method is called after any `AsyncRequest` is pushed to the main loop + # So, the background writing is still active + # before the worker put call_idx to `comp_q` + if not blocking: + is_alive = True + break + sleep(0.1) + + if self.cur_item is not None: + logger.debug( + f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}" + f" is completed, {is_alive}" + ) + + is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive) + # This is set to False when blocking == False so this routine is called again + # to simply call `sync_all_async_calls` to check if other ranks complete the writing + if is_done: + # The current request is completed globally. Reset the current item for polling. + logger.debug( + f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}" + f" is completed globally, {is_done}" + ) + self.cur_item = None + + return is_done + + def close(self): + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller" + ) + if self.process: + self.queue.put('DONE') + self.queue.join() + self.process.join() + self.process = None + + @staticmethod + @_disable_gc() + def async_loop( + rank: int, + queue: mp.JoinableQueue, + preload_q: mp.JoinableQueue, + comp_q: mp.Queue, + log_level: int = logging.INFO, + ): + """Main function for the persistent checkpoint worker + + The persisent worker is created once and terminated at exit or + when application calls `close()` explictily + + This routine receives `AsyncRequest` and does `preload_fn` first and + put the integer value in `preload_q` to inform the trainer to proceed. + When the `async_fn` from the request` is completed (background saving is done), + it puts a integer value to `comp_q` to notify the trainer the completion. + + Args: + rank (int): the rank of the trainer where the persistent worker is created. 
+ queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest + from the training rank + preload_q (mp.JoinableQueue): a queue to inform trainer that preloading of tensors + from GPU to Host or dedicated location is completed + comp_q (mp.Queue): a queue to inform the training rank the completion of scheduled + async checkpoint request + log_level (int, Optional): an integer to set log-level in this spawned process + to get aligned with the training rank's logging level + + """ + logger = logging.getLogger(__name__) + logger.setLevel(log_level) + logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started") + while True: + item = queue.get() + if isinstance(item, str) and item == 'DONE': + queue.task_done() + break + elif isinstance(item, AsyncRequest): + async_fn_args = list(item.async_fn_args) + if item.preload_fn: + call_idx = preload_q.get() + # the 2nd arg is state dict + async_fn_args[1] = item.preload_fn() + logger.debug(f"{rank} has completed D2H of {call_idx}") + preload_q.task_done() + item.async_fn(*async_fn_args, **item.async_fn_kwargs) + logger.debug(f"{rank} has completed saving {item.call_idx}") + comp_q.put(item.call_idx) + queue.task_done() + + logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated") + + +class _ActiveAsyncRequest(NamedTuple): + """Helper to represent an active async call. + + Args: + idx (int): index of the call (starting from 0) + async_caller (DistributedAsyncCaller): async caller instance that represents + the async process handling the async request + async_request (AsyncRequest): async request that is being called + """ + + idx: int + async_caller: AsyncCaller + async_request: AsyncRequest + + +class AsyncCallsQueue: + """Manages a queue of async calls. + + Allows adding a new async call with `schedule_async_request` and finalizing + active calls with `maybe_finalize_async_calls`. + """ + + def __init__(self, persistent: bool = False): + self.async_calls: deque[_ActiveAsyncRequest] = deque([]) + self.call_idx: int = -1 + self.persistent: bool = persistent + self.persistent_caller: AsyncCaller = None + + def _get_async_caller(self): + if not self.persistent: + return TemporalAsyncCaller() + if self.persistent_caller is None: + self.persistent_caller = PersistentAsyncCaller() + return self.persistent_caller + + def schedule_async_request(self, async_request: AsyncRequest) -> int: + """Start a new async call and add it to a queue of active async calls. + + This method must be called on all ranks. + + Args: + async_request (AsyncRequest): async request to start. + + Returns: + int: index of the async call that was started. + This can help the user keep track of the async calls. + """ + self.call_idx += 1 + async_caller = self._get_async_caller() + # Backward compatibility for local checkpointing built with the old AsyncRequest + if len(async_request._fields) != len(AsyncRequest._fields): + async_request = AsyncRequest(**async_request._asdict()) + + async_request = async_request._replace(call_idx=self.call_idx) + finalize_fns = async_request.finalize_fns + async_request = async_request._replace(finalize_fns=None) + async_request = async_request.freeze() + async_caller.schedule_async_call(async_request) + self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns)) + return self.call_idx + + def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]: + """Finalizes all available calls. + + This method must be called on all ranks. 
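A training-loop sketch for the queue above, assuming every rank executes these calls and that `save(..., async_sharded_save=True)` returns the `AsyncRequest` shown earlier in this patch; `model` and `ckpt_dir` are placeholders.

```python
# Hypothetical driver code; all of these calls must run on every rank.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue

async_queue = AsyncCallsQueue(persistent=True)  # reuse one PersistentAsyncCaller worker


def save_checkpoint_async(model, ckpt_dir: str) -> int:
    async_request = dist_checkpointing.save(
        model.sharded_state_dict(), ckpt_dir, async_sharded_save=True
    )
    # Hand the request to the background worker; its finalize_fns run later, on all ranks.
    return async_queue.schedule_async_request(async_request)


def poll_pending_checkpoints() -> None:
    # Cheap enough to call every iteration; finalizes only requests completed on all ranks.
    for call_idx in async_queue.maybe_finalize_async_calls(blocking=False):
        print(f'checkpoint request {call_idx} finalized')


def shutdown() -> None:
    async_queue.close()  # drains outstanding saves and stops the persistent worker
```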
+ + Args: + blocking (bool, optional): if True, will wait until all active requests + are done. Otherwise, finalizes only the async request that already + finished. Defaults to False. + Returns: + List[int]: list of indices (as returned by `schedule_async_request`) + of async calls that have been successfully finalized. + """ + call_idx_finalized = [] + while self.async_calls: + next_async_done = self.async_calls[0].async_caller.is_current_async_call_done( + blocking, no_dist + ) + if not next_async_done: + break + with debug_time("finalize", logger): + call_idx, _, finalize_fns = self.async_calls.popleft() + ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX) + assert ten.item() == call_idx, 'Unmatched async calls. ' + 'That probably means not all ranks are participating in async finalization' + for finalize_fn in finalize_fns: + finalize_fn() + call_idx_finalized.append(call_idx) + return call_idx_finalized + + def get_num_unfinalized_calls(self): + """Get the number of active async calls.""" + return len(self.async_calls) + + def close(self): + """Finalize all calls upon closing.""" + self.maybe_finalize_async_calls(blocking=True) + if self.persistent and self.persistent_caller: + self.persistent_caller.close() diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py index cdcdd49f446d76afa2bc5857b8d623ce589c5393..7409a6a4d363fd5a1e1e95281e8bef1e05dd66c0 100644 --- a/megatron/core/dist_checkpointing/strategies/base.py +++ b/megatron/core/dist_checkpointing/strategies/base.py @@ -1,227 +1,228 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Strategies base interfaces. """ - -from abc import ABC, abstractmethod -from collections import defaultdict -from enum import Enum -from pathlib import Path -from typing import Any, DefaultDict, Union - -from ..mapping import CheckpointingException, ShardedStateDict, StateDict -from .async_utils import AsyncCallsQueue, AsyncRequest - - -class StrategyAction(Enum): - """Specifies save vs load and sharded vs common action.""" - - LOAD_COMMON = 'load_common' - LOAD_SHARDED = 'load_sharded' - SAVE_COMMON = 'save_common' - SAVE_SHARDED = 'save_sharded' - - -default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict) - -async_calls = AsyncCallsQueue() - - -def get_default_strategy(action: StrategyAction, backend: str, version: int): - """Retrieves a default strategy for a given action, backend and version.""" - try: - if backend == 'zarr': - error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages' - from .tensorstore import register_default_tensorstore_strategies - - register_default_tensorstore_strategies() - from .zarr import register_default_zarr_strategies - - register_default_zarr_strategies() - elif backend == 'torch_dist': - error_hint = ' Please use PyTorch version >=2.1' - from .torch import register_default_torch_strategies - - register_default_torch_strategies() - except ImportError as e: - raise CheckpointingException( - f'Cannot import a default strategy for: {(action.value, backend, version)}. ' - f'Error: {e}. 
Hint: {error_hint}' - ) from e - try: - return default_strategies[action.value][(backend, version)] - except KeyError as e: - raise CheckpointingException( - f'Cannot find a default strategy for: {(action.value, backend, version)}' - ) from e - - -def register_default_strategy( - action: StrategyAction, - backend: str, - version: int, - strategy: Union['SaveStrategyBase', 'LoadStrategyBase'], -): - """Adds a given strategy to the registry of default strategies. - - Args: - action (StrategyAction): specifies save/load and sharded/common - backend (str): backend that the strategy becomes a default for - version (int): version that the strategy becomes a default for - strategy (SaveStrategyBase, LoadStrategyBase): strategy to register - """ - default_strategies[action.value][(backend, version)] = strategy - - -class LoadStrategyBase(ABC): - """Base class for a load strategy. Requires implementing checks for compatibility with a - given checkpoint version.""" - - @abstractmethod - def check_backend_compatibility(self, loaded_backend): - """Verifies if this strategy is compatible with `loaded_backend`.""" - raise NotImplementedError - - @abstractmethod - def check_version_compatibility(self, loaded_version): - """Verifies if this strategy is compatible with `loaded_version`.""" - raise NotImplementedError - - @property - def can_handle_sharded_objects(self): - """Returns whether or not this strategy can handle loading ShardedObjects.""" - return False - - -class SaveStrategyBase(ABC): - """Base class for a save strategy. Requires defining a backend type and - version of the saved format.""" - - def __init__(self, backend: str, version: int): - self.backend = backend - self.version = version - - @property - def can_handle_sharded_objects(self): - """Returns whether or not this strategy can handle saving ShardedObjects.""" - return False - - def __str__(self): - return f'{self.__class__.__name__}({self.backend}, {self.version})' - - -class LoadCommonStrategy(LoadStrategyBase): - """Load strategy for common (non-sharded) objects""" - - @abstractmethod - def load_common(self, checkpoint_dir: Path): - """Load common part of the checkpoint.""" - raise NotImplementedError - - @abstractmethod - def load_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path - ): - """Load sharded objects from the checkpoint.""" - raise NotImplementedError - - def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: - """Load just the metadata from the checkpoint.""" - if not self.can_handle_sharded_objects: - return {} - raise NotImplementedError - - -class LoadShardedStrategy(LoadStrategyBase): - """Load strategy for sharded tensors""" - - @abstractmethod - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - """Load the sharded part of the checkpoint.""" - raise NotImplementedError - - @abstractmethod - def load_tensors_metadata(self, checkpoint_dir: Path): - """Load tensors metadata from the checkpoint for ShardedTensors. - - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply ShardedTensor keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors without any data and sharding (so, the - only useful information is tensors global shape and dtype). 
- """ - raise NotImplementedError( - f'Loading only tensors metadata not implemented for {self.__class__.__name__}' - ) - - def load_sharded_metadata(self, checkpoint_dir: Path): - """Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects. - - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply sharded keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors or ShardedObjects without any data and sharding. - """ - if not self.can_handle_sharded_objects: - return self.load_tensors_metadata(checkpoint_dir) - raise NotImplementedError( - f'Loading only sharded metadata not implemented for {self.__class__.__name__}' - ) - - def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): - """Remove all tensors whose key starts with key_prefix""" - raise NotImplementedError - - -class SaveCommonStrategy(SaveStrategyBase): - """Save strategy for common (non-sharded) objects""" - - @abstractmethod - def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): - """Save common part of the state dict.""" - raise NotImplementedError - - def save_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path - ): - """Save sharded objects from the state dict.""" - raise NotImplementedError - - -class SaveShardedStrategy(SaveStrategyBase): - """Save strategy for sharded tensors""" - - @abstractmethod - def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - """Save the sharded part of the state dict.""" - raise NotImplementedError - - -class AsyncSaveShardedStrategy(SaveShardedStrategy): - """Save strategy suitable for async save.""" - - @abstractmethod - def async_save( - self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path - ) -> AsyncRequest: - """Perform preparation and return an AsyncRequest to the external caller. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to save - checkpoint_dir (Path): checkpoint target directory - - Returns: - AsyncRequest: represents the async save function and finalization function. - It is the caller responsibility to actually schedule the async save. - """ - raise NotImplementedError - - def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - """Each async strategy can be trivially used as a sync strategy.""" - async_request = self.async_save(sharded_state_dict, checkpoint_dir) - # multiprocessing routines may cause issue when called on parent process - # We keep this verbose call for now - global async_calls - async_calls.schedule_async_request(async_request) - async_calls.maybe_finalize_async_calls(blocking=True) +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Strategies base interfaces. 
""" + +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import Enum +from pathlib import Path +from typing import Any, DefaultDict, Union + +from ..mapping import CheckpointingException, ShardedStateDict, StateDict +from .async_utils import AsyncCallsQueue, AsyncRequest + + +class StrategyAction(Enum): + """Specifies save vs load and sharded vs common action.""" + + LOAD_COMMON = 'load_common' + LOAD_SHARDED = 'load_sharded' + SAVE_COMMON = 'save_common' + SAVE_SHARDED = 'save_sharded' + + +default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict) + +async_calls = AsyncCallsQueue() + + +def get_default_strategy(action: StrategyAction, backend: str, version: int): + """Retrieves a default strategy for a given action, backend and version.""" + error_hint: str = None + try: + if backend == 'zarr': + error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages' + from .tensorstore import register_default_tensorstore_strategies + + register_default_tensorstore_strategies() + from .zarr import register_default_zarr_strategies + + register_default_zarr_strategies() + elif backend == 'torch_dist': + error_hint = ' Please use PyTorch version >=2.1' + from .torch import register_default_torch_strategies + + register_default_torch_strategies() + except ImportError as e: + raise CheckpointingException( + f'Cannot import a default strategy for: {(action.value, backend, version)}. ' + f'Error: {e}. Hint: {error_hint}' + ) from e + try: + return default_strategies[action.value][(backend, version)] + except KeyError as e: + raise CheckpointingException( + f'Cannot find a default strategy for: {(action.value, backend, version)}' + ) from e + + +def register_default_strategy( + action: StrategyAction, + backend: str, + version: int, + strategy: Union['SaveStrategyBase', 'LoadStrategyBase'], +): + """Adds a given strategy to the registry of default strategies. + + Args: + action (StrategyAction): specifies save/load and sharded/common + backend (str): backend that the strategy becomes a default for + version (int): version that the strategy becomes a default for + strategy (SaveStrategyBase, LoadStrategyBase): strategy to register + """ + default_strategies[action.value][(backend, version)] = strategy + + +class LoadStrategyBase(ABC): + """Base class for a load strategy. Requires implementing checks for compatibility with a + given checkpoint version.""" + + @abstractmethod + def check_backend_compatibility(self, loaded_backend): + """Verifies if this strategy is compatible with `loaded_backend`.""" + raise NotImplementedError + + @abstractmethod + def check_version_compatibility(self, loaded_version): + """Verifies if this strategy is compatible with `loaded_version`.""" + raise NotImplementedError + + @property + def can_handle_sharded_objects(self): + """Returns whether or not this strategy can handle loading ShardedObjects.""" + return False + + +class SaveStrategyBase(ABC): + """Base class for a save strategy. 
Requires defining a backend type and + version of the saved format.""" + + def __init__(self, backend: str, version: int): + self.backend = backend + self.version = version + + @property + def can_handle_sharded_objects(self): + """Returns whether or not this strategy can handle saving ShardedObjects.""" + return False + + def __str__(self): + return f'{self.__class__.__name__}({self.backend}, {self.version})' + + +class LoadCommonStrategy(LoadStrategyBase): + """Load strategy for common (non-sharded) objects""" + + @abstractmethod + def load_common(self, checkpoint_dir: Path): + """Load common part of the checkpoint.""" + raise NotImplementedError + + @abstractmethod + def load_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Load sharded objects from the checkpoint.""" + raise NotImplementedError + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + """Load just the metadata from the checkpoint.""" + if not self.can_handle_sharded_objects: + return {} + raise NotImplementedError + + +class LoadShardedStrategy(LoadStrategyBase): + """Load strategy for sharded tensors""" + + @abstractmethod + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Load the sharded part of the checkpoint.""" + raise NotImplementedError + + @abstractmethod + def load_tensors_metadata(self, checkpoint_dir: Path): + """Load tensors metadata from the checkpoint for ShardedTensors. + + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply ShardedTensor keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors without any data and sharding (so, the + only useful information is tensors global shape and dtype). + """ + raise NotImplementedError( + f'Loading only tensors metadata not implemented for {self.__class__.__name__}' + ) + + def load_sharded_metadata(self, checkpoint_dir: Path): + """Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects. + + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply sharded keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors or ShardedObjects without any data and sharding. 
+ """ + if not self.can_handle_sharded_objects: + return self.load_tensors_metadata(checkpoint_dir) + raise NotImplementedError( + f'Loading only sharded metadata not implemented for {self.__class__.__name__}' + ) + + def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): + """Remove all tensors whose key starts with key_prefix""" + raise NotImplementedError + + +class SaveCommonStrategy(SaveStrategyBase): + """Save strategy for common (non-sharded) objects""" + + @abstractmethod + def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): + """Save common part of the state dict.""" + raise NotImplementedError + + def save_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Save sharded objects from the state dict.""" + raise NotImplementedError + + +class SaveShardedStrategy(SaveStrategyBase): + """Save strategy for sharded tensors""" + + @abstractmethod + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Save the sharded part of the state dict.""" + raise NotImplementedError + + +class AsyncSaveShardedStrategy(SaveShardedStrategy): + """Save strategy suitable for async save.""" + + @abstractmethod + def async_save( + self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path + ) -> AsyncRequest: + """Perform preparation and return an AsyncRequest to the external caller. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to save + checkpoint_dir (Path): checkpoint target directory + + Returns: + AsyncRequest: represents the async save function and finalization function. + It is the caller responsibility to actually schedule the async save. + """ + raise NotImplementedError + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Each async strategy can be trivially used as a sync strategy.""" + async_request = self.async_save(sharded_state_dict, checkpoint_dir) + # multiprocessing routines may cause issue when called on parent process + # We keep this verbose call for now + global async_calls + async_calls.schedule_async_request(async_request) + async_calls.maybe_finalize_async_calls(blocking=True) diff --git a/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fc2c5a8aea51e64249520a1603e2643c2280ff --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" FS Reader with metadata cached support. """ + +import os +from typing import Union + +from torch.distributed.checkpoint import FileSystemReader, Metadata + + +class CachedMetadataFileSystemReader(FileSystemReader): + """ + Extends FileSystemReader to cache metadata for improved performance. + + Attributes: + _cached_metadata (Metadata or None): Cached metadata from the file system. + """ + + def __init__(self, path: Union[str, os.PathLike]) -> None: + """ + Initialize with file system path. + + Args: + path (Union[str, os.PathLike]): Path to the checkpoint directory or file. + """ + super().__init__(path=path) + self._cached_metadata = None + + def read_metadata(self) -> Metadata: + """ + Read metadata from file system, caching for subsequent calls. + + Returns: + Metadata: Checkpoint metadata. 
+ """ + if self._cached_metadata is None: + self._cached_metadata = super().read_metadata() + return self._cached_metadata diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index f2c87b4d606c161d2989de4a1e3daa1177c51c57..8e51b5c8d75579e96b33d89374572fea1d3a342f 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -1,157 +1,157 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -""" Common strategies. """ - -import logging -import os -from pathlib import Path - -import torch - -from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict -from megatron.core.dist_checkpointing.strategies.base import ( - SaveCommonStrategy, - StrategyAction, - register_default_strategy, -) - -from ..dict_utils import dict_list_map_inplace, nested_values -from ..mapping import CheckpointingException, ShardedObject, is_main_replica -from ..strategies.base import LoadCommonStrategy - -COMMON_STATE_FNAME = 'common.pt' - -logger = logging.getLogger(__name__) - - -def register_default_common_strategies(): - """Register default common strategies.""" - register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy()) - register_default_strategy( - StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1) - ) - - -class TorchCommonSaveStrategy(SaveCommonStrategy): - """Common save strategy leveraging native torch save/load.""" - - def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): - """Save common part of the state dict.""" - if torch.distributed.get_rank() == 0: - torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) - - def save_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path - ): - """Save sharded objects from the state dict.""" - for sh_obj in nested_values(sharded_objects_state_dict): - if is_main_replica(sh_obj.replica_id): - save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' - os.makedirs(save_path.parent, exist_ok=True) - torch.save(sh_obj.data, save_path) - - def can_handle_sharded_objects(self): - """This strategy can handle ShardedObjects.""" - return True - - -class TorchCommonLoadStrategy(LoadCommonStrategy): - """Common load strategy leveraging native torch save/load.""" - - def load_common(self, checkpoint_dir: Path): - """Load common (non-sharded) objects state dict from the checkpoint. - - Args: - checkpoint_dir (Path): checkpoint directory - - Returns: - StateDict: state dict with non-sharded objects from the checkpoint - """ - load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME - try: - return torch.load(load_path, map_location='cpu') - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - ckpt_files = [f.name for f in checkpoint_dir.iterdir()] - logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}') - raise CheckpointingException(err_msg) from e - - def load_sharded_objects( - self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path - ): - """Replaces all ShardedObject from a given state dict with values loaded from the - checkpoint. - - Args: - sharded_objects_state_dict (ShardedStateDict): - sharded state dict defining what objects should be loaded. 
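Stepping back to the `CachedMetadataFileSystemReader` added just above, a small usage sketch; the checkpoint path is a placeholder.

```python
# Illustrative only; '/path/to/torch_dist_ckpt' stands in for a real checkpoint directory.
from megatron.core.dist_checkpointing.strategies.cached_metadata_filesystem_reader import (
    CachedMetadataFileSystemReader,
)

reader = CachedMetadataFileSystemReader('/path/to/torch_dist_ckpt')
meta_first = reader.read_metadata()   # hits storage and populates _cached_metadata
meta_again = reader.read_metadata()   # served from the cache, no extra file I/O
assert meta_first is meta_again
```

Because it subclasses `FileSystemReader`, the cached reader can presumably be passed wherever PyTorch distributed checkpointing accepts a `storage_reader`.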
- checkpoint_dir (Path): checkpoint directory - - Returns: - None: sharded state dict is modified in place - """ - - def load_sharded_object(sh_obj: ShardedObject): - sh_obj.data = None - load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' - try: - loaded_obj = torch.load(load_path) - except FileNotFoundError as e: - # Backward compatible logic: previously the save format was incorrect - old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') - try: - loaded_obj = torch.load(old_load_path) - except FileNotFoundError: - err_msg = f'Object shard {load_path} not found' - obj_subdir = checkpoint_dir / sh_obj.key - if obj_subdir.exists(): - obj_files = [f.name for f in obj_subdir.iterdir()] - logger.debug( - f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}' - ) - else: - ckpt_files = [f.name for f in checkpoint_dir.iterdir()] - logger.debug( - f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint' - f' directory content: {ckpt_files}' - ) - raise CheckpointingException(err_msg) from e - return loaded_obj - - return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict) - - def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: - sharded_metadata = {} - for subdir in checkpoint_dir.iterdir(): - if not subdir.is_dir(): - continue - shard_files = list(subdir.glob('shard_*.pt')) - if not shard_files: - continue - sh_objs = [] - for shard_file in shard_files: - full_key = f'{subdir.name}/{shard_file.stem}' - sh_objs.append(ShardedObject.empty_from_unique_key(full_key)) - - # This is a backward-compatibility fix, where the last global shape is missing in the - # name - if sh_objs[0].global_shape[-1] < 0: - max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs)) - for sh_obj in sh_objs: - sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1) - - # Update the sharded state dict - for sh_obj in sh_objs: - sharded_metadata[sh_obj.unique_key] = sh_obj - return sharded_metadata - - @property - def can_handle_sharded_objects(self): - """This strategy can handle ShardedObjects.""" - return True - - def check_backend_compatibility(self, loaded_version): - pass - - def check_version_compatibility(self, loaded_version): - pass +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" Common strategies. 
""" + +import logging +import os +from pathlib import Path + +import torch + +from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict +from megatron.core.dist_checkpointing.strategies.base import ( + SaveCommonStrategy, + StrategyAction, + register_default_strategy, +) + +from ..dict_utils import dict_list_map_inplace, nested_values +from ..mapping import CheckpointingException, ShardedObject, is_main_replica +from ..strategies.base import LoadCommonStrategy + +COMMON_STATE_FNAME = 'common.pt' + +logger = logging.getLogger(__name__) + + +def register_default_common_strategies(): + """Register default common strategies.""" + register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy()) + register_default_strategy( + StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1) + ) + + +class TorchCommonSaveStrategy(SaveCommonStrategy): + """Common save strategy leveraging native torch save/load.""" + + def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): + """Save common part of the state dict.""" + if torch.distributed.get_rank() == 0: + torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) + + def save_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Save sharded objects from the state dict.""" + for sh_obj in nested_values(sharded_objects_state_dict): + if is_main_replica(sh_obj.replica_id): + save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' + os.makedirs(save_path.parent, exist_ok=True) + torch.save(sh_obj.data, save_path) + + def can_handle_sharded_objects(self): + """This strategy can handle ShardedObjects.""" + return True + + +class TorchCommonLoadStrategy(LoadCommonStrategy): + """Common load strategy leveraging native torch save/load.""" + + def load_common(self, checkpoint_dir: Path): + """Load common (non-sharded) objects state dict from the checkpoint. + + Args: + checkpoint_dir (Path): checkpoint directory + + Returns: + StateDict: state dict with non-sharded objects from the checkpoint + """ + load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME + try: + return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + ckpt_files = [f.name for f in checkpoint_dir.iterdir()] + logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}') + raise CheckpointingException(err_msg) from e + + def load_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Replaces all ShardedObject from a given state dict with values loaded from the + checkpoint. + + Args: + sharded_objects_state_dict (ShardedStateDict): + sharded state dict defining what objects should be loaded. 
+ checkpoint_dir (Path): checkpoint directory + + Returns: + None: sharded state dict is modified in place + """ + + def load_sharded_object(sh_obj: ShardedObject): + sh_obj.data = None + load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' + try: + loaded_obj = torch.load(load_path, weights_only=False) + except FileNotFoundError as e: + # Backward compatible logic: previously the save format was incorrect + old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') + try: + loaded_obj = torch.load(old_load_path, weights_only=False) + except FileNotFoundError: + err_msg = f'Object shard {load_path} not found' + obj_subdir = checkpoint_dir / sh_obj.key + if obj_subdir.exists(): + obj_files = [f.name for f in obj_subdir.iterdir()] + logger.debug( + f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}' + ) + else: + ckpt_files = [f.name for f in checkpoint_dir.iterdir()] + logger.debug( + f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint' + f' directory content: {ckpt_files}' + ) + raise CheckpointingException(err_msg) from e + return loaded_obj + + return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict) + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + sharded_metadata = {} + for subdir in checkpoint_dir.iterdir(): + if not subdir.is_dir(): + continue + shard_files = list(subdir.glob('shard_*.pt')) + if not shard_files: + continue + sh_objs = [] + for shard_file in shard_files: + full_key = f'{subdir.name}/{shard_file.stem}' + sh_objs.append(ShardedObject.empty_from_unique_key(full_key)) + + # This is a backward-compatibility fix, where the last global shape is missing in the + # name + if sh_objs[0].global_shape[-1] < 0: + max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs)) + for sh_obj in sh_objs: + sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1) + + # Update the sharded state dict + for sh_obj in sh_objs: + sharded_metadata[sh_obj.unique_key] = sh_obj + return sharded_metadata + + @property + def can_handle_sharded_objects(self): + """This strategy can handle ShardedObjects.""" + return True + + def check_backend_compatibility(self, loaded_version): + pass + + def check_version_compatibility(self, loaded_version): + pass diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index 47ab4d112638fe18b1ab5c1cdf63a7fbfc5d0ae1..ef7053b61b0016aa46df6caa0fb57bd3a2b3c5ba 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -1,439 +1,496 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -""" Storage writer for PyT Distributed format allowing asynchronous save. 
""" -import gc -import logging -import os -import queue -from contextlib import contextmanager -from itertools import chain -from pathlib import Path -from time import time -from typing import Callable, Dict, List, Optional, Tuple, Union - -import psutil -import torch -from torch import multiprocessing as mp -from torch.distributed.checkpoint import FileSystemWriter -from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item -from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType -from torch.distributed.checkpoint.storage import WriteResult -from torch.futures import Future - -logger = logging.getLogger(__name__) - -WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file - -_results_queue = None - - -def _get_write_results_queue(): - global _results_queue - if _results_queue is None: - ctx = mp.get_context('spawn') - _results_queue = ctx.Manager().Queue() - return _results_queue - - -@contextmanager -def _disable_gc(): - """Temporarily disables GC.""" - gc_enabled = gc.isenabled() - try: - if gc_enabled: - gc.disable() - yield - finally: - if gc_enabled: - gc.enable() - - -class FileSystemWriterAsync(FileSystemWriter): - """ - Async-enabled implementation of FileSystemWriter using file IO. - - This class doesn't spawn the async process itself, relies on the external async mechanism. - - Flow: - 1. Call `write_data` - 2. Externally start async process with `get_save_function_and_args` function and args - 3. The async function to call is `writer_proxy_func` which calls - `write_preloaded_data` in multiple processes - - After saving is finalized on all ranks: - 4. Call `super().finish` with the results gathered in `self.writer_result` - - Note that step (3) above can also be called synchronously. - - Currently, it's assumed that a separate writer is created for each ckpt save - (intermediate state is stored as writer attributes). - """ - - def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs): - super().__init__(*args, **kwargs) - if not self.single_file_per_rank: - raise NotImplementedError( - 'single_file_per_rank flag not supported for FileSystemWriterAsync' - ) - - # Intermediate state between preparation and finalization - self.write_buckets: Optional[List[WriteBucket]] = None - self.results_queue: Optional[mp.Queue] = None - self.separation_hint = separation_hint - - def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None: - """ - First stage of async saving. Copy data to CPU and plan the local saving. 
- - Args: - plan (SavePlan): save plan generated by the PyT Distributed compatible planner - planner (SavePlanner): save planner used to resolve the bytes and tensor data - - Returns: None, but stores the save plan in `self.write_buckets` - """ - storage_plan: _StoragePrefix = plan.storage_data - start = time() - logger.debug(f"thread_count: {self.thread_count}, time: {start}") - if self.separation_hint: - assert ( - self.thread_count > 1 - ), "thread_count must be at least 2 if separation_hint is provided" - bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count - item_buckets = _split_by_size_and_type(bins, plan.items, self.separation_hint) - logger.debug(f"bucket_prep, time: {time() - start}") - - start = time() - # move tensors from GPU to CPU before starting async writing - # We do D2H synchronously for now - file_count = 0 - - def gen_file(prefix=""): - nonlocal file_count - file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" - file_count += 1 - return file_name - - # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process - self.write_buckets = [] - for group_name, group_buckets in _split_by_separation_hint( - item_buckets, self.separation_hint - ).items(): - for bucket in group_buckets: - bytes_data = [ - (item, planner.resolve_data(item)) - for item in bucket - if item.type == WriteItemType.BYTE_IO - ] - tensor_data = [ - (item, planner.resolve_data(item).detach().to("cpu", non_blocking=True)) - for item in bucket - if item.type != WriteItemType.BYTE_IO - ] - if len(bytes_data) > 0 or len(tensor_data) > 0: - file_name = gen_file(prefix=group_name) - self.write_buckets.append( - (self.path / file_name, file_name, (bytes_data, tensor_data)) - ) - - # Check if there is anything to write on this rank - if len(self.write_buckets) > 0: - assert len(self.write_buckets) <= self.thread_count, ( - len(self.write_buckets), - self.thread_count, - ) - self.results_queue = _get_write_results_queue() - else: - self.results_queue = None - end = time() - logger.debug(f"D2H and push, time: {end - start}") - - def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]: - """ - Get function that saves the data to storage along with its arguments. - Allows the external caller to apply the save function synchronously or asynchronously. - - Returns: None (if there is nothing to write on this rank) or a tuple of: - - the function that saves the data - - arguments to that function - """ - if not self.write_buckets: - return None, () - return (self.write_preloaded_data_multiproc, (self.write_buckets, self.results_queue)) - - @staticmethod - @_disable_gc() - def write_preloaded_data_multiproc( - write_buckets: List[WriteBucket], global_results_queue: mp.Queue - ) -> None: - """ - Performs saving data to storage with multiple processes. - - Starts predefined number of processes and uses 2 queues to make sure the results - are complete: - - local_results_queue - to send the actual results - - count_queue - small queue to mark worker as completed - - Using just one queue disallowed proper exception handling. - - This method is meant to be run in a forked subprocess. - Triggering GC during execution leads to CUDA errors - (cleaning up tensors owned by the parent process). - To prevent this, we disable the GC explicitly for this function with _disable_gc. 
- - Args: - write_buckets (List[WriteBucket]): write plan - global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]] - (or an Exception) from parallel write processes to the main training process - Returns: None - """ - w_start = time() - write_results_or_exc: Union[dict, Exception] = dict() - ctx = mp.get_context('fork') - local_results_queue = ctx.Queue() - count_queue = ctx.JoinableQueue() - p_list = [] - for i, write_bucket in enumerate(write_buckets): - try: - count_queue.put(i) - p_list.append( - ctx.Process( - target=FileSystemWriterAsync.write_preloaded_data, - args=(i, write_bucket, local_results_queue, count_queue, True), - ) - ) - except Exception as e: - err_msg = f'An error is caught while a proc {i} is created, error: {e}' - logger.error(err_msg) - write_results_or_exc = RuntimeError(err_msg) - - if not isinstance(write_results_or_exc, Exception): - for p in p_list: - p.start() - - logger.debug('FileSystemWriterAsync: collecting worker results...') - - # To make sure all nodes are completed - count_queue.join() - # At this point, all workers completed, so the queue should have exactly - # `len(write_buckets)` items - for proc_idx in range(len(write_buckets)): - try: - local_proc_idx, local_results_or_exc = local_results_queue.get() - except queue.Empty: - write_results_or_exc = RuntimeError( - f'Unexpected empty `local_results_queue`' - f' (got only {proc_idx}/{len(write_buckets)} items)' - ) - break - else: - if isinstance(local_results_or_exc, Exception): - err_msg = ( - f"Local process {local_proc_idx} encountered" - f" an error: {local_results_or_exc}" - ) - logger.error(err_msg) - write_results_or_exc = local_results_or_exc - break - else: - assert isinstance(local_results_or_exc, list), type(local_results_or_exc) - write_results_or_exc[local_proc_idx] = local_results_or_exc - p_list[local_proc_idx].join() - - logger.debug('FileSystemWriterAsync: collected worker results successfully') - - global_results_queue.put(write_results_or_exc) - - w_end = time() - logger.debug( - f"{w_end}, rank: {torch.distributed.get_rank()}," - f" write(sync,parallel): {w_end - w_start}" - ) - - @staticmethod - @_disable_gc() - def write_preloaded_data( - local_proc_idx: int, - write_bucket: WriteBucket, - results_queue: mp.SimpleQueue, - count_queue: mp.JoinableQueue, - use_fsync: bool, - ) -> None: - """ - Performs actual data saving to storage. - - Args: - local_proc_idx (int): index of a local process that performs writing - write_bucket (WriteBucket): data to write to storage - results_queue (mp.Queue): queue to return the write results - to the proxy checkpoint process. - count_queue (mp.JoinableQueue): queue to marks worker task as completed - use_fsync (bool): if True, calls os.fsync at the end of saving - - Returns: None, the write result are put into the `queue` - """ - mem_before = _process_memory() - - local_results = [] - try: - file_name, storage_key, (bytes_data, tensor_data) = write_bucket - with open(file_name, "wb") as stream: - for write_item, data in bytes_data: - local_results.append(_write_item(stream, data, write_item, storage_key)) - - for write_item, tensor in tensor_data: - assert tensor.is_cpu - local_results.append(_write_item(stream, tensor, write_item, storage_key)) - - if use_fsync: - os.fsync(stream.fileno()) - local_output = (local_proc_idx, local_results) - except Exception as e: - local_output = (local_proc_idx, e) - - results_queue.put(local_output) - # Signal this process is done. 
- count_queue.get() - count_queue.task_done() - - mem_after = _process_memory() - logger.debug( - f"{local_proc_idx} consumed: {mem_after - mem_before}," - f" before: {mem_before}, after: {mem_after}" - ) - - def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]: - """Write all items from ``plan``.""" - raise NotImplementedError('write_data not implemented for FileSystemWriterAsync') - - def retrieve_write_results(self) -> List[WriteResult]: - """ - Turn the latest dict including write results from `self.results_queue` - into a single results lists. Includes error check. - - Returns (List[WriteResult]): the list of write results - from all local processes performing the save. - - """ - assert self.write_buckets is not None - - if self.results_queue is None: - write_results_or_exc = {} - else: - try: - write_results_or_exc = self.results_queue.get_nowait() - except queue.Empty: - raise RuntimeError(f'results_queue should not be empty') - - if isinstance(write_results_or_exc, Exception): - raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc - write_results: dict = write_results_or_exc - if len(write_results) != len(self.write_buckets): - raise RuntimeError( - f'Incomplete worker results (expected {len(self.write_buckets)},' - f' got {len(write_results)}. This probably indicates a worker failure.' - ) - return list(chain.from_iterable(write_results.values())) - - -def _split_by_size_and_type( - bins: int, items: List[WriteItem], separation_hint: Optional[str] = None -) -> List[List[WriteItem]]: - """ - Splits write items according to item size into close to uniform bins. - - Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, - but with a fixed _item_size function. 
- - Args: - bins (int): numbers of bins to split to - items (List[WriteItem]): list of write items - - Returns (List[List[WriteItem]]): write items split to bins - """ - if bins == 1: - return [items] - - bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] - tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] - - buckets: List[List[WriteItem]] = [[] for _ in range(bins)] - bucket_sizes = [0 for _ in range(bins)] - - tensor_items.sort(key=_item_size, reverse=True) - - # Assign bytes with a simple round-robin - for i, item in enumerate(bytes_items): - buckets[i % bins].append(item) - - # Then, assign tensors according to their sizes - for item in tensor_items: - # TODO replace with headq - idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] - buckets[idx].append(item) - bucket_sizes[idx] += _item_size(item) - - return buckets - - -def _split_by_separation_hint( - buckets: List[List[WriteItem]], separation_hint: Optional[str] = None -) -> Dict[str, List[List[WriteItem]]]: - """ - Splits buckets into those whose keys begin with the separation_hint and those whose keys do not - - Args: - buckets (List[List[WriteItem]]): buckets to split - separation_hint (Optional[str]): optional prefix to split on - - Returns (Dict[str, List[List[WriteItem]]]): a dictionary - mapping the prefix to the relevant buckets - """ - bins = len(buckets) - buckets_with_separation_hint = {} - if separation_hint is not None: - buckets_default = [[] for _ in range(bins)] - buckets_hint = [[] for _ in range(bins)] - for i in range(bins): - for item in buckets[i]: - if item.index.fqn.startswith(separation_hint): - buckets_hint[i].append(item) - else: - buckets_default[i].append(item) - buckets_with_separation_hint[""] = buckets_default - buckets_with_separation_hint[separation_hint] = buckets_hint - else: - buckets_with_separation_hint[""] = buckets - return buckets_with_separation_hint - - -def _item_size(item: WriteItem) -> int: - """ - Calculates size (in bytes) of a single write item. - - Same as torch.distributed.checkpoint.filesystem._item_size, - but fixes computing chunk size (with item.tensor_data.chunk.sizes) - - Args: - item (WriteItem): write item to compute the size of - - Returns (int): size of an item in bytes - """ - size = 1 - assert item.tensor_data is not None - # can't use math.prod as PT needs to support older python - for s in item.tensor_data.chunk.sizes: - size *= s - - dtype = item.tensor_data.properties.dtype - return size * torch._utils._element_size(dtype) - - -def _process_memory() -> int: - """ - Get memory used by current process. - - Returns (int): memory used by current process - """ - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - return mem_info.rss +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" Storage writer for PyT Distributed format allowing asynchronous save. 
""" +import dataclasses +import logging +import os +import queue +from functools import partial +from heapq import heappop, heappush +from itertools import chain +from operator import itemgetter +from pathlib import Path +from time import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch +from torch import multiprocessing as mp +from torch.distributed.checkpoint import FileSystemWriter +from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + +from .async_utils import _disable_gc + +logger = logging.getLogger(__name__) + +WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file + +_results_queue = None + + +def _get_write_results_queue(): + global _results_queue + if _results_queue is None: + ctx = mp.get_context('spawn') + _results_queue = ctx.Manager().Queue() + return _results_queue + + +class FileSystemWriterAsync(FileSystemWriter): + """ + Async-enabled implementation of FileSystemWriter using file IO. + + This class doesn't spawn the async process itself, relies on the external async mechanism. + + Flow: + 1. Call `write_data` + 2. Externally start async process with `get_save_function_and_args` function and args + 3. The async function to call is `writer_proxy_func` which calls + `write_preloaded_data` in multiple processes + + After saving is finalized on all ranks: + 4. Call `super().finish` with the results gathered in `self.writer_result` + + Note that step (3) above can also be called synchronously. + + Currently, it's assumed that a separate writer is created for each ckpt save + (intermediate state is stored as writer attributes). + """ + + def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + if not self.single_file_per_rank: + raise NotImplementedError( + 'single_file_per_rank flag not supported for FileSystemWriterAsync' + ) + + self.can_run_decentralized_global_plan: bool = True + + # Intermediate state between preparation and finalization + self.write_buckets: Optional[List[WriteBucket]] = None + self.results_queue: Optional[mp.Queue] = None + self.separation_hint = separation_hint + + def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None: + """ + First stage of async saving. Copy data to CPU and plan the local saving. 
+ + Args: + plan (SavePlan): save plan generated by the PyT Distributed compatible planner + planner (SavePlanner): save planner used to resolve the bytes and tensor data + + Returns: None, but stores the save plan in `self.write_buckets` + """ + storage_plan: _StoragePrefix = plan.storage_data + start = time() + logger.debug(f"thread_count: {self.thread_count}, time: {start}") + if self.separation_hint: + assert ( + self.thread_count > 1 + ), "thread_count must be at least 2 if separation_hint is provided" + bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count + item_buckets = _split_by_size_and_type(bins, plan.items) + logger.debug(f"bucket_prep, time: {time() - start}") + + start = time() + # move tensors from GPU to CPU before starting async writing + # We do D2H synchronously for now + file_count = 0 + + def gen_file(prefix=""): + nonlocal file_count + file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + def _clone_if_needed(ten: torch.Tensor): + """Clone if we detect incontiguous storage for CPU tensors + + Makes sure we perform a `clone` only if we detect incontiguous storage, + so that we don't blow up host memory unnecessarily. + + TODO: For persistent worker, this work should be changed to move the cpu tensor + to shared_memory. + """ + ten = ten.detach() + if ten.device.type != "cpu": + # We do D2H later when the async_request is scheduled for both sync / async + # checkpointing + return ten + is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize + return ten.clone() if is_view else ten + + # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process + self.write_buckets = [] + for group_name, group_buckets in _split_by_separation_hint( + item_buckets, self.separation_hint + ).items(): + for bucket in group_buckets: + bytes_data = [ + (item, planner.resolve_data(item)) + for item in bucket + if item.type == WriteItemType.BYTE_IO + ] + tensor_data = [ + (item, _clone_if_needed(planner.resolve_data(item))) + for item in bucket + if item.type != WriteItemType.BYTE_IO + ] + if len(bytes_data) > 0 or len(tensor_data) > 0: + file_name = gen_file(prefix=group_name) + self.write_buckets.append( + (self.path / file_name, file_name, (bytes_data, tensor_data)) + ) + + # Check if there is anything to write on this rank + if len(self.write_buckets) > 0: + assert len(self.write_buckets) <= self.thread_count, ( + len(self.write_buckets), + self.thread_count, + ) + self.results_queue = _get_write_results_queue() + else: + self.results_queue = None + end = time() + logger.debug(f"D2H and push, time: {end - start}") + + def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]: + """ + Get function that saves the data to storage along with its arguments. + Allows the external caller to apply the save function synchronously or asynchronously. + + Returns: None (if there is nothing to write on this rank) or a tuple of: + 1) the function that saves the data. + 2) the function that stages the GPU tensors to a destination for async checkpointing. + This function should be self-contained. + 3) arguments to that function in 1). 
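Put together, one plausible synchronous driving sequence for the triple described above, assuming a writer on which prepare_write_data(plan, planner) has already been called and assuming the caller substitutes the staged buckets for the originals before invoking the save function (the real async path instead hands these callables to the external async machinery):

    save_fn, preload_fn, save_args = writer.get_save_function_and_args()
    if save_fn is not None:                  # None means nothing to write on this rank
        rank, _buckets, results_queue = save_args
        staged_buckets = preload_fn()        # D2H copy; all tensor data ends up on CPU
        save_fn(rank, staged_buckets, results_queue)   # or run in a separate process
    write_results = writer.retrieve_write_results()    # flattened List[WriteResult]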
+ """ + if not self.write_buckets: + return None, None, () + return ( + self.write_preloaded_data_multiproc, + partial(self.preload_tensors, self.write_buckets, True), + [torch.distributed.get_rank(), self.write_buckets, self.results_queue], + ) + + @staticmethod + def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]: + """Preload tensors in state_dict to host memory through CPU memory + Args: + write_buckets(List): List of `WriteBucket`, + which includes what to be saved in a checkpoint + non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True. + """ + result = [] + + for bucket in write_buckets: + file_name, storage_key, (bytes_data, tensor_data) = bucket + tensor_data = [ + (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data + ] + result.append((file_name, storage_key, (bytes_data, tensor_data))) + if non_blocking: + torch.cuda.synchronize() + return result + + @staticmethod + @_disable_gc() + def write_preloaded_data_multiproc( + rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue + ) -> None: + """ + Performs saving data to storage with multiple processes. + + Starts predefined number of processes and uses 2 queues to make sure the results + are complete: + - local_results_queue - to send the actual results + - count_queue - small queue to mark worker as completed + + Using just one queue disallowed proper exception handling. + + This method is meant to be run in a forked subprocess. + Triggering GC during execution leads to CUDA errors + (cleaning up tensors owned by the parent process). + To prevent this, we disable the GC explicitly for this function with _disable_gc. + + Args: + write_buckets (List[WriteBucket]): write plan + global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]] + (or an Exception) from parallel write processes to the main training process + Returns: None + """ + logger = logging.getLogger(__name__) + w_start = time() + write_results_or_exc: Union[dict, Exception] = dict() + ctx = mp.get_context('fork') + local_results_queue = ctx.Queue() + count_queue = ctx.JoinableQueue() + p_list = [] + for i, write_bucket in enumerate(write_buckets): + try: + count_queue.put(i) + p_list.append( + ctx.Process( + target=FileSystemWriterAsync.write_preloaded_data, + args=(i, write_bucket, local_results_queue, count_queue, True), + ) + ) + except Exception as e: + err_msg = f'An error is caught while a proc {i} is created, error: {e}' + logger.error(err_msg) + write_results_or_exc = RuntimeError(err_msg) + + if not isinstance(write_results_or_exc, Exception): + for p in p_list: + p.start() + + logger.debug('FileSystemWriterAsync: collecting worker results...') + + # To make sure all nodes are completed + count_queue.join() + # At this point, all workers completed, so the queue should have exactly + # `len(write_buckets)` items + for proc_idx in range(len(write_buckets)): + try: + local_proc_idx, local_results_or_exc = local_results_queue.get() + except queue.Empty: + write_results_or_exc = RuntimeError( + f'Unexpected empty `local_results_queue`' + f' (got only {proc_idx}/{len(write_buckets)} items)' + ) + break + else: + if isinstance(local_results_or_exc, Exception): + err_msg = ( + f"Local process {local_proc_idx} encountered" + f" an error: {local_results_or_exc}" + ) + logger.error(err_msg) + write_results_or_exc = local_results_or_exc + break + assert isinstance(local_results_or_exc, list), type(local_results_or_exc) + 
write_results_or_exc[local_proc_idx] = local_results_or_exc + p_list[local_proc_idx].join() + + logger.debug('FileSystemWriterAsync: collected worker results successfully') + + global_results_queue.put(write_results_or_exc) + + w_end = time() + logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}") + + @staticmethod + @_disable_gc() + def write_preloaded_data( + local_proc_idx: int, + write_bucket: WriteBucket, + results_queue: mp.SimpleQueue, + count_queue: mp.JoinableQueue, + use_fsync: bool, + ) -> None: + """ + Performs actual data saving to storage. + + Args: + local_proc_idx (int): index of a local process that performs writing + write_bucket (WriteBucket): data to write to storage + results_queue (mp.Queue): queue to return the write results + to the proxy checkpoint process. + count_queue (mp.JoinableQueue): queue to marks worker task as completed + use_fsync (bool): if True, calls os.fsync at the end of saving + + Returns: None, the write result are put into the `queue` + """ + logger = logging.getLogger(__name__) + logger.debug(f'{local_proc_idx} started') + mem_before = _process_memory() + + local_results = [] + try: + file_name, storage_key, (bytes_data, tensor_data) = write_bucket + with open(file_name, "wb") as stream: + for write_item, data in bytes_data: + local_results.append(_write_item(stream, data, write_item, storage_key)) + + for write_item, tensor in tensor_data: + assert tensor.is_cpu + local_results.append(_write_item(stream, tensor, write_item, storage_key)) + + if use_fsync: + os.fsync(stream.fileno()) + local_output = (local_proc_idx, local_results) + except Exception as e: + logger.debug(f'{local_proc_idx} failed') + local_output = (local_proc_idx, e) + + results_queue.put(local_output) + # Signal this process is done. + count_queue.get() + count_queue.task_done() + + mem_after = _process_memory() + logger.debug( + f"{local_proc_idx} consumed: {mem_after - mem_before}," + f" before: {mem_before}, after: {mem_after}" + ) + + def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]: + """Write all items from ``plan``.""" + raise NotImplementedError('write_data not implemented for FileSystemWriterAsync') + + def retrieve_write_results(self) -> List[WriteResult]: + """ + Turn the latest dict including write results from `self.results_queue` + into a single results lists. Includes error check. + + Returns (List[WriteResult]): the list of write results + from all local processes performing the save. + + """ + assert self.write_buckets is not None + + if self.results_queue is None: + write_results_or_exc = {} + else: + try: + write_results_or_exc = self.results_queue.get_nowait() + except queue.Empty: + raise RuntimeError(f'results_queue should not be empty') + + if isinstance(write_results_or_exc, Exception): + raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc + write_results: dict = write_results_or_exc + if len(write_results) != len(self.write_buckets): + raise RuntimeError( + f'Incomplete worker results (expected {len(self.write_buckets)},' + f' got {len(write_results)}. This probably indicates a worker failure.' + ) + return list(chain.from_iterable(write_results.values())) + + def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Instead of assigning indices by plan order, uses PyT rank (same outcome). 
+ + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + return dataclasses.replace( + local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_") + ) + + +def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]: + """ + Splits write items according to item size into close to uniform bins. + + Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, + but with a fixed _item_size function. + + Args: + bins (int): numbers of bins to split to + items (List[WriteItem]): list of write items + + Returns (List[List[WriteItem]]): write items split to bins + """ + if bins == 1: + return [items] + + bytes_items: List[WriteItem] = [] + tensor_items: List[WriteItem] = [] + for wi in items: + container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items + container.append(wi) + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + # Assign bytes with a simple round-robin + for i, item in enumerate(bytes_items): + buckets[i % bins].append(item) + + # Sort tensor items by size in decreasing order once and store the size with item + sized_tensors = [(item, _item_size(item)) for item in tensor_items] + sized_tensors.sort(key=itemgetter(1), reverse=True) + + # Use a min heap for bin assignment + # Store (total_size_of_bin, bin_index) tuples + heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)] + + # Assign tensors using heap + for item, size in sized_tensors: + total_bin_size, bin_idx = heappop(heap) + buckets[bin_idx].append(item) + heappush(heap, (total_bin_size + size, bin_idx)) + + return buckets + + +def _split_by_separation_hint( + buckets: List[List[WriteItem]], separation_hint: Optional[str] = None +) -> Dict[str, List[List[WriteItem]]]: + """ + Splits buckets into those whose keys begin with the separation_hint and those whose keys do not + + Args: + buckets (List[List[WriteItem]]): buckets to split + separation_hint (Optional[str]): optional prefix to split on + + Returns (Dict[str, List[List[WriteItem]]]): a dictionary + mapping the prefix to the relevant buckets + """ + bins = len(buckets) + buckets_with_separation_hint = {} + if separation_hint is not None: + buckets_default = [[] for _ in range(bins)] + buckets_hint = [[] for _ in range(bins)] + for i in range(bins): + for item in buckets[i]: + if item.index.fqn.startswith(separation_hint): + buckets_hint[i].append(item) + else: + buckets_default[i].append(item) + buckets_with_separation_hint[""] = buckets_default + buckets_with_separation_hint[separation_hint] = buckets_hint + else: + buckets_with_separation_hint[""] = buckets + return buckets_with_separation_hint + + +def _item_size(item: WriteItem) -> int: + """ + Calculates size (in bytes) of a single write item. 
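The tensor assignment above is a greedy longest-processing-time bin packing: sort item sizes in decreasing order, then always drop the next item into the currently lightest bucket via a min-heap. A self-contained sketch of the same idea with plain integer sizes (independent of WriteItem):

    from heapq import heappop, heappush

    def pack(sizes, bins):
        heap = [(0, i) for i in range(bins)]      # (current bucket size, bucket index)
        buckets = [[] for _ in range(bins)]
        for size in sorted(sizes, reverse=True):  # largest items first
            total, idx = heappop(heap)            # lightest bucket so far
            buckets[idx].append(size)
            heappush(heap, (total + size, idx))
        return buckets

    print(pack([7, 5, 4, 3, 2, 1], bins=2))  # [[7, 3, 1], [5, 4, 2]] -> totals 11 and 11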
+ + Same as torch.distributed.checkpoint.filesystem._item_size, + but fixes computing chunk size (with item.tensor_data.chunk.sizes) + + Args: + item (WriteItem): write item to compute the size of + + Returns (int): size of an item in bytes + """ + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.chunk.sizes: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _process_memory() -> int: + """ + Get memory used by current process. + + Returns (int): memory used by current process + """ + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py index 56523daf1cf266cbe73a7b9f8182e1b0c030d741..3fb46c587490eb0520de28ede15b684bf19c8db9 100644 --- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -1,439 +1,515 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import logging -from pathlib import Path -from time import time -from typing import Dict, Optional, Tuple - -import torch -import torch.distributed as dist - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.core import CheckpointingException -from megatron.core.dist_checkpointing.dict_utils import ( - dict_list_map_inplace, - extract_matching_values, - merge, - nested_values, -) -from megatron.core.dist_checkpointing.exchange_utils import ( - ShardDistribution, - determine_main_replica_uniform_distribution, - exchange_by_distribution, -) -from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica -from megatron.core.dist_checkpointing.strategies.base import ( - AsyncSaveShardedStrategy, - LoadShardedStrategy, - SaveShardedStrategy, -) -from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId -from megatron.core.dist_checkpointing.validation import ( - determine_global_metadata, - validate_sharding_integrity, -) - -logger = logging.getLogger(__name__) - - -class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy): - """Wraps arbitrary strategy and distributes the save during `save`. - - The save distribution happens without any *data* communication. - Only the *metadata* is exchanged and based on data replication on different - ranks, we try to distribute the save as uniformly as possible. - - This wrapper assumes, that setting `replica_id` to 0 will make the - underlying strategy do the saving on current rank. All the other `replica_id`s - are set to 1. - - Currently, the save distribution is realized with a greedy algorithm - described in `distribute_shards_to_ranks`. - - Args: - strategy (SaveShardedStrategy): base strategy to wrap - parallelization_group (ProcessGroup, optional): process group to use for save - distribution. Note that this doesn't have to match exactly the - data distribution, but should cover the replication pattern - to maximize performance. Defaults to the whole world. - do_cache_distribution (bool, optional): whether to cache the save distribution - from previous calls. Should be set to True only if the state dict - structure between the calls is always the same. Defaults to True. 
- """ - - def __init__( - self, - strategy: SaveShardedStrategy, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, - do_cache_distribution: bool = False, - ): - super().__init__(strategy.backend, strategy.version) - self.base_strategy = strategy - self.parallelization_group = parallelization_group - self.do_cache_distribution = do_cache_distribution - - self.cached_distribution: Optional[ShardDistribution] = None - - def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - if not isinstance(self.base_strategy, AsyncSaveShardedStrategy): - raise CheckpointingException( - f'Cannot apply async_save to non-async base strategy {self.base_strategy}' - ) - self.apply_saving_parallelization(sharded_state_dict) - return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir) - - def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - self.apply_saving_parallelization(sharded_state_dict) - return self.base_strategy.save(sharded_state_dict, checkpoint_dir) - - def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None: - """Distributes the save across ranks by exchanging metadata. - - Exchanges metadata from the state dict and computes the uniform - (as close as possible) distribution of saves among the ranks. - - If `self.do_cache_distribution` is True, caches the distribution between - the calls and subsequent distributions happen without any inter-rank - communication. - - Args: - sharded_state_dict (ShardedStateDict): state dict to distribute the saving - - Returns: None - """ - start = time() - if self.do_cache_distribution and self.cached_distribution is not None: - logger.debug(f'Apply *cached* save parallelization') - precomputed_distribution = self.cached_distribution - else: - logger.debug(f'Apply save parallelization') - precomputed_distribution = determine_main_replica_uniform_distribution( - sharded_state_dict, self.parallelization_group - ) - - distribute_main_replicas_with_precomputed_distribution( - sharded_state_dict, self.parallelization_group, precomputed_distribution - ) - if self.cached_distribution is None: - # First time applying the parallelization - validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1]) - if self.do_cache_distribution: - self.cached_distribution = precomputed_distribution - end = time() - logger.debug(f"parallel save sharding, time: {end - start}") - - @property - def can_handle_sharded_objects(self): - return self.base_strategy.can_handle_sharded_objects - - -class FullyParallelLoadStrategyWrapper(LoadShardedStrategy): - """Wraps arbitrary load strategy and distributes the load during `load`. - - See `load` method docs for details. - - Args: - strategy (LoadShardedStrategy): base strategy to wrap - parallelization_group (ProcessGroup, optional): process group to use for load - distribution. Note that this doesn't have to match exactly the - data distribution, but should cover the replication pattern - to maximize performance. Defaults to the whole world. - In most cases, it's recommended to set it to the DP group. - do_cache_distribution (bool, optional): whether to cache the load distribution - from previous calls. Should be set to True only if the state dict - structure between the calls is always the same. Defaults to False, - since the loading in general happens only once during training. - Note that the load distribution *cannot* be reused as a save distribution, - because save/load is not fully symmetrical. 
- exchange_algo (str): algorithm to use for exchanging the data. - Options: - - broadcast - each rank broadcasts individual tensors to others - - gather_object (default) - ranks all_gather_object the whole loaded state dicts - - gather_rounds (default) - ranks all gather individual tensors in rounds - See method docs for more details. - """ - - def __init__( - self, - strategy: LoadShardedStrategy, - parallelization_group: Optional[torch.distributed.ProcessGroup] = None, - do_cache_distribution: bool = False, - exchange_algo: str = 'broadcast', - ): - super().__init__() - self.base_strategy = strategy - if parallelization_group is None: - parallelization_group = ( - dist.GroupMember.WORLD - ) # explicit group needed for torch.distributed.get_global_rank call - self.parallelization_group = parallelization_group - self.do_cache_distribution = do_cache_distribution - self.exchange_algo = exchange_algo - - self.cached_distribution: Optional[ShardDistribution] = None - - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: - """Distributes the load and calls underlying strategy only for parts of the state dict. - - Steps: - 1. Load metadata is exchanged between the ranks in the parallelization group. - 2. Each rank deterministically plans the load for the whole workload - so that the loads are as uniform as possible. - 3. Each ranks loads its planned shard of the checkpoint. - 4. All ranks exchange the loaded shards. - - Internode communication is involved in steps (1) (with metadata) - and (4) (with actual data). Storage interaction is involved in step (3). - - Currently, the load distribution (step 2) is realized with a greedy algorithm - described in `distribute_shards_to_ranks` (same as for saving distribution). - - Currently, the shards are all gathered between all ranks in the parallelization - group. This might not be optimal (some ranks do not need all tensors), - but it's a reasonable approximation for an optimal exchange in most scenarios. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to load - checkpoint_dir (Path): checkpoint directory to load from - - Returns: - StateDict: loaded state dict. The state dict should be equivalent to - a state dict that would be loaded with the underlying strategy - without this wrapper. - """ - if torch.distributed.get_world_size(self.parallelization_group) <= 1: - return self.base_strategy.load(sharded_state_dict, checkpoint_dir) - - # Step 1 and 2: exchange load metadata and distribute the load - start = time() - precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict) - assert ( - precomputed_distribution is not None - ), 'Expecting non-trivial distribution for non-trivial parallelization group' - end = time() - logger.debug(f'self.apply_loading_parallelization took {end - start}s') - start = end - - # Step 3: load part of the checkpoint. - # Load only sharded objects first. 
ShardedTensors will be loaded separately - # so that we can keep track of sharded tensors loaded by this rank - (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = ( - self._defer_loading_sharded_tensors(sharded_state_dict) - ) - loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir) - - end = time() - logger.debug(f'Base load of ShardedObjects took {end - start}s') - start = end - - # Load sharded tensors separately - loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir) - - end = time() - logger.debug(f'Base load of ShardedTensors took {end - start}s') - start = end - - # Step 4: exchange data between ranks - logger.debug(f'Applying parallel load with algo {self.exchange_algo}') - all_loaded_tensors = exchange_by_distribution( - loaded_tensors, - unloaded_shards, - precomputed_distribution, - self.parallelization_group, - self.exchange_algo, - ) - if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()): - missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys() - raise CheckpointingException( - f'Missing shards after fully parallel loading: {missing_shards}' - ) - - sync_start = time() - torch.cuda.synchronize() - end = time() - logger.debug(f'torch.cuda.synchronize took {end - sync_start}s') - logger.debug(f'self.exchange_loaded_tensors took {end - start}s') - - self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors) - merge(loaded_state_dict, sharded_tensors) - return loaded_state_dict - - def _defer_loading_sharded_tensors( - self, sharded_state_dict: ShardedStateDict - ) -> Tuple[ - ShardedStateDict, - ShardedStateDict, - Dict[_ShardId, ShardedTensor], - Dict[_ShardId, ShardedTensor], - ]: - """Divides state dict into parts loaded by this vs other ranks. - - ShardedTensors with main replica_id will be loaded by this rank, - others will be received by other ranks (after loading from storage). - - Args: - sharded_state_dict (ShardedStateDict): state dict with ShardedTensor - that will be divided. - - Returns: a tuple of: - - ShardedStateDict: sub-state dict only with ShardedTensors - - ShardedStateDict: sub-state dict with non-ShardedTensors - - Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified - by shard ids. This is a mapping from shard id to a corresponding - ShardedTensor for tensors loaded by *this* rank - - Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding - ShardedTensor for tensors loaded by *other* ranks - """ - to_load_shards = {} - unloaded_shards = {} - - sharded_tensors, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedTensor) - ) - - def wrap_non_main_replicas(x): - if isinstance(x, ShardedTensor): - # Assign shard to be loaded or not - if is_main_replica(x.replica_id): - to_load_shards[_sharded_tensor_shard_id(x)] = x - else: - unloaded_shards[_sharded_tensor_shard_id(x)] = x - return x - - dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors) - return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards - - def apply_loading_parallelization( - self, sharded_state_dict: ShardedStateDict - ) -> Optional[ShardDistribution]: - """Distributes the load across ranks by exchanging metadata. - - Exchanges metadata from the state dict and computes the uniform - (as close as possible) distribution of loads among the ranks. - Marks ShardedTensors to be loaded by the current rank with replica_id 0 - (and others with non 0 values). 
- - If `self.do_cache_distribution` is True, caches the distribution between - the calls and subsequent distributions happen without any inter-rank - communication. - - Args: - sharded_state_dict (ShardedStateDict): state dict to distribute the loading - - Returns: - ShardDistribution (optional): the computed loading distribution - """ - if self.do_cache_distribution and self.cached_distribution is not None: - logger.debug(f'Apply *cached* load parallelization') - precomputed_distribution = self.cached_distribution - else: - logger.debug(f'Apply load parallelization') - precomputed_distribution = determine_main_replica_uniform_distribution( - sharded_state_dict, self.parallelization_group, True - ) - - distribute_main_replicas_with_precomputed_distribution( - sharded_state_dict, self.parallelization_group, precomputed_distribution - ) - if self.do_cache_distribution: - self.cached_distribution = precomputed_distribution - - return precomputed_distribution - - def fill_in_deferred_sharded_tensors( - self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor] - ) -> None: - """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to fill in. - ShardedTensors are completely replaced with corresponding torch.Tensors. - loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map - ShardedTensor from the sharded_state_dict to loaded tensors. - - Returns: - - """ - - def fill_in_sharded_tensor(x): - if isinstance(x, ShardedTensor): - try: - x = loaded_tensors[_sharded_tensor_shard_id(x)] - except KeyError as e: - raise CheckpointingException( - f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}' - ) from e - - return x - - dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict) - - @property - def can_handle_sharded_objects(self): - return self.base_strategy.can_handle_sharded_objects - - def load_tensors_metadata(self, checkpoint_dir: Path): - return self.base_strategy.load_tensors_metadata(checkpoint_dir) - - def load_sharded_metadata(self, checkpoint_dir: Path): - return self.base_strategy.load_sharded_metadata(checkpoint_dir) - - def check_backend_compatibility(self, loaded_version): - return self.base_strategy.check_backend_compatibility(loaded_version) - - def check_version_compatibility(self, loaded_version): - return self.base_strategy.check_version_compatibility(loaded_version) - - -def distribute_main_replicas_with_precomputed_distribution( - sharded_state_dict: ShardedStateDict, - parallelization_group: torch.distributed.ProcessGroup, - precomputed_distribution: Optional[ShardDistribution], -): - """Applies the save distribution computed with `determine_main_replica_uniform_distribution`. - - Based on rank assignment, sets replica ids of the shards saved by current rank to 0 - and all the other replica ids to 1. - - Args: - sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to - parallelization_group (ProcessGroup): distribution will be applied within this - process group. Must match with the process group passed to - `determine_main_replica_uniform_distribution`. 
- precomputed_distribution (ShardDistribution): distribution computed with - `determine_main_replica_uniform_distribution` - - Returns: None - - Example replica ids of tensors A, B, C before distribution: - rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0) - rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1) - rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2) - - Replicas after distribution for the example above: - rank0: A: 0, B: 1, C: 1 - rank1: A: 1, B: 0, C: 1 - rank2: A: 1, B: 1, C: 0 - """ - if torch.distributed.get_world_size(group=parallelization_group) <= 1: - return - if precomputed_distribution is None: - raise ValueError( - 'precomputed_distribution must be not None for non-trivial parallelization group' - ) - - local_shards = list( - sh_base - for sh_base in nested_values(sharded_state_dict) - if isinstance(sh_base, ShardedTensor) - ) - - rank_within_dp_group = torch.distributed.get_rank(parallelization_group) - for sh_ten in local_shards: - shard_id = _sharded_tensor_shard_id(sh_ten) - if ( - shard_id in precomputed_distribution.shards_in_this_group - and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id] - ): - sh_ten.replica_id = 0 - else: - sh_ten.replica_id = 1 +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +from pathlib import Path +from time import time +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint import Metadata + +from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor +from megatron.core.dist_checkpointing.core import CheckpointingException +from megatron.core.dist_checkpointing.dict_utils import ( + dict_list_map_inplace, + extract_matching_values, + merge, + nested_values, +) +from megatron.core.dist_checkpointing.exchange_utils import ( + ShardDistribution, + determine_main_replica_uniform_distribution, + exchange_by_distribution, + exchange_loaded_objects_gather_object, +) +from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica +from megatron.core.dist_checkpointing.strategies.base import ( + AsyncSaveShardedStrategy, + LoadShardedStrategy, + SaveShardedStrategy, +) +from megatron.core.dist_checkpointing.utils import ( + _sharded_object_id, + _sharded_tensor_shard_id, + _ShardId, + debug_time, +) +from megatron.core.dist_checkpointing.validation import ( + determine_global_metadata, + validate_sharding_integrity, +) + +logger = logging.getLogger(__name__) + +T = TypeVar('T', ShardedObject, ShardedTensor) + + +class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy): + """Wraps arbitrary strategy and distributes the save during `save`. + + The save distribution happens without any *data* communication. + Only the *metadata* is exchanged and based on data replication on different + ranks, we try to distribute the save as uniformly as possible. + + This wrapper assumes, that setting `replica_id` to 0 will make the + underlying strategy do the saving on current rank. All the other `replica_id`s + are set to 1. + + Currently, the save distribution is realized with a greedy algorithm + described in `distribute_shards_to_ranks`. + + Args: + strategy (SaveShardedStrategy): base strategy to wrap + parallelization_group (ProcessGroup, optional): process group to use for save + distribution. Note that this doesn't have to match exactly the + data distribution, but should cover the replication pattern + to maximize performance. 
Defaults to the whole world. + do_cache_distribution (bool, optional): whether to cache the save distribution + from previous calls. Should be set to True only if the state dict + structure between the calls is always the same. Defaults to True. + """ + + def __init__( + self, + strategy: SaveShardedStrategy, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + do_cache_distribution: bool = False, + ): + super().__init__(strategy.backend, strategy.version) + self.base_strategy = strategy + self.parallelization_group = parallelization_group + self.do_cache_distribution = do_cache_distribution + + self.cached_distribution: Optional[ShardDistribution] = None + + def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + if not isinstance(self.base_strategy, AsyncSaveShardedStrategy): + raise CheckpointingException( + f'Cannot apply async_save to non-async base strategy {self.base_strategy}' + ) + self.apply_saving_parallelization(sharded_state_dict) + return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir) + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + self.apply_saving_parallelization(sharded_state_dict) + return self.base_strategy.save(sharded_state_dict, checkpoint_dir) + + def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None: + """Distributes the save across ranks by exchanging metadata. + + Exchanges metadata from the state dict and computes the uniform + (as close as possible) distribution of saves among the ranks. + + If `self.do_cache_distribution` is True, caches the distribution between + the calls and subsequent distributions happen without any inter-rank + communication. + + Args: + sharded_state_dict (ShardedStateDict): state dict to distribute the saving + + Returns: None + """ + start = time() + if self.do_cache_distribution and self.cached_distribution is not None: + logger.debug(f'Apply *cached* save parallelization') + precomputed_distribution = self.cached_distribution + else: + logger.debug(f'Apply save parallelization') + precomputed_distribution = determine_main_replica_uniform_distribution( + sharded_state_dict, self.parallelization_group + ) + + distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict, self.parallelization_group, precomputed_distribution + ) + if self.cached_distribution is None: + # First time applying the parallelization + validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1]) + if self.do_cache_distribution: + self.cached_distribution = precomputed_distribution + end = time() + logger.debug(f"parallel save sharding, time: {end - start}") + + @property + def can_handle_sharded_objects(self): + return self.base_strategy.can_handle_sharded_objects + + +class FullyParallelLoadStrategyWrapper(LoadShardedStrategy): + """Wraps arbitrary load strategy and distributes the load during `load`. + + See `load` method docs for details. + + Args: + strategy (LoadShardedStrategy): base strategy to wrap + parallelization_group (ProcessGroup, optional): process group to use for load + distribution. Note that this doesn't have to match exactly the + data distribution, but should cover the replication pattern + to maximize performance. Defaults to the whole world. + In most cases, it's recommended to set it to the DP group. + do_cache_distribution (bool, optional): whether to cache the load distribution + from previous calls. 
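The greedy `distribute_shards_to_ranks` algorithm referenced in the wrapper docstrings lives in `exchange_utils` and is not shown in this diff. Below is a minimal standalone sketch of the idea only, assuming string shard ids and known byte sizes; it is not the MCore implementation.

```python
from typing import Dict, List

def greedy_assign_shards(
    shard_sizes: Dict[str, int],
    covering_ranks: Dict[str, List[int]],
) -> Dict[str, int]:
    """Assign each shard to the least-loaded rank that already holds a replica of it."""
    load: Dict[int, int] = {}          # rank -> total bytes assigned so far
    assignment: Dict[str, int] = {}
    # Placing the largest shards first tends to give a more even split.
    for shard, size in sorted(shard_sizes.items(), key=lambda kv: -kv[1]):
        best = min(covering_ranks[shard], key=lambda r: load.get(r, 0))
        assignment[shard] = best
        load[best] = load.get(best, 0) + size
    return assignment

# Example: two ranks both hold every shard; total bytes end up balanced.
print(greedy_assign_shards({"a": 4, "b": 3, "c": 2}, {k: [0, 1] for k in "abc"}))
# {'a': 0, 'b': 1, 'c': 1}
```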
Should be set to True only if the state dict + structure between the calls is always the same. Defaults to False, + since the loading in general happens only once during training. + Note that the load distribution *cannot* be reused as a save distribution, + because save/load is not fully symmetrical. + exchange_algo (str): algorithm to use for exchanging the data. + Options: + - broadcast - each rank broadcasts individual tensors to others + - gather_object (default) - ranks all_gather_object the whole loaded state dicts + - gather_rounds (default) - ranks all gather individual tensors in rounds + See method docs for more details. + """ + + def __init__( + self, + strategy: LoadShardedStrategy, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + do_cache_distribution: bool = False, + exchange_algo: str = 'broadcast', + ): + super().__init__() + self.base_strategy = strategy + if parallelization_group is None: + parallelization_group = ( + dist.GroupMember.WORLD + ) # explicit group needed for torch.distributed.get_global_rank call + self.parallelization_group = parallelization_group + self.do_cache_distribution = do_cache_distribution + self.exchange_algo = exchange_algo + + self.cached_distribution: Optional[ShardDistribution] = None + self.cached_global_metadata: Optional[Metadata] = None + + @debug_time("FullyParallelLoadStrategyWrapper.load", logger) + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: + """Distributes the load and calls underlying strategy only for parts of the state dict. + + Steps: + 1. Load metadata is exchanged between the ranks in the parallelization group. + 2. Each rank deterministically plans the load for the whole workload + so that the loads are as uniform as possible. + 3. Each ranks loads its planned shard of the checkpoint. + 4. All ranks exchange the loaded shards. + + Internode communication is involved in steps (1) (with metadata) + and (4) (with actual data). Storage interaction is involved in step (3). + + Currently, the load distribution (step 2) is realized with a greedy algorithm + described in `distribute_shards_to_ranks` (same as for saving distribution). + + Currently, the shards are all gathered between all ranks in the parallelization + group. This might not be optimal (some ranks do not need all tensors), + but it's a reasonable approximation for an optimal exchange in most scenarios. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to load + checkpoint_dir (Path): checkpoint directory to load from + + Returns: + StateDict: loaded state dict. The state dict should be equivalent to + a state dict that would be loaded with the underlying strategy + without this wrapper. + """ + + loaded_state_dict = {} + + if torch.distributed.get_world_size(self.parallelization_group) <= 1: + return self.base_strategy.load(sharded_state_dict, checkpoint_dir) + + # Step 1 and 2: exchange load metadata and distribute the load + with debug_time("self.apply_loading_parallelization", logger): + precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization( + sharded_state_dict + ) + assert ( + precomputed_distribution is not None + ), 'Expecting non-trivial distribution for non-trivial parallelization group' + + # Step 3: load part of the checkpoint. + # Load only sharded objects first. 
ShardedTensors will be loaded separately + # so that we can keep track of sharded tensors loaded by this rank + (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = ( + self._defer_loading_sharded_tensors(sharded_state_dict) + ) + + (sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = ( + self._defer_loading_sharded_objects(sharded_state_dict) + ) + + assert ( + len(sharded_state_dict) == 0 + ), "sharded_state_dict is not empty after deferring tensors and objects" + with debug_time("base_load_ShardedObjects", logger): + # Load sharded objects first + loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir) + + with debug_time("base_load_ShardedTensors", logger): + # Load sharded tensors separately + loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir) + + with debug_time("self.exchange_loaded_tensors", logger): + + # Step 4: exchange data between ranks + logger.debug(f'Applying parallel load with algo {self.exchange_algo}') + all_loaded_tensors = exchange_by_distribution( + loaded_tensors, + unloaded_shards, + precomputed_distribution, + self.parallelization_group, + self.exchange_algo, + ) + if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()): + missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys() + raise CheckpointingException( + f'Missing shards after fully parallel loading: {missing_shards}' + ) + + with debug_time("torch.cuda.synchronize", logger): + torch.cuda.synchronize() + + all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects) + + if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()): + missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys() + raise CheckpointingException( + f'Missing object shards after fully parallel loading: {missing_object_shards}' + ) + torch.cuda.synchronize() + + self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors) + self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects) + + merge(loaded_state_dict, sharded_objects) + merge(loaded_state_dict, sharded_tensors) + if hasattr(self.base_strategy, "cached_global_metadata"): + self.cached_global_metadata = self.base_strategy.cached_global_metadata + return loaded_state_dict + + @staticmethod + def _defer_loading_sharded_objects( + sharded_state_dict: ShardedStateDict, + ) -> Tuple[ + ShardedStateDict, + ShardedStateDict, + Dict[_ShardId, ShardedObject], + Dict[_ShardId, ShardedObject], + ]: + return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id) + + @staticmethod + def _defer_loading_sharded_tensors( + sharded_state_dict: ShardedStateDict, + ) -> Tuple[ + ShardedStateDict, + ShardedStateDict, + Dict[_ShardId, ShardedTensor], + Dict[_ShardId, ShardedTensor], + ]: + return _defer_loading_sharded_items( + sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id + ) + + @staticmethod + def fill_in_deferred_sharded_objects( + sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any] + ) -> None: + """Fill in objects not loaded by current rank with objects from `loaded_objects` map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to fill in. + ShardedObjects are completely replaced with corresponding objects. + loaded_objects (Dict[_ShardId, Any]): dict allowing to map + ShardedObject from the sharded_state_dict to loaded objects. 
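To make the defer / fill-in flow concrete, here is a toy round trip with plain dictionaries and a made-up `ToyShard` type; the names are illustrative only and do not exist in MCore.

```python
from dataclasses import dataclass
from typing import Dict

@dataclass
class ToyShard:
    key: str
    replica_id: int  # 0 stands in for "main replica" here; MCore uses is_main_replica()

def split_by_main_replica(shards: Dict[str, ToyShard]):
    to_load, deferred = {}, {}
    for name, shard in shards.items():
        (to_load if shard.replica_id == 0 else deferred)[name] = shard
    return to_load, deferred

shards = {"w1": ToyShard("w1", 0), "w2": ToyShard("w2", 1)}
to_load, deferred = split_by_main_replica(shards)

# "w1" is read from storage by this rank; "w2" arrives via the exchange step
# and then replaces its placeholder, mirroring the fill_in_deferred_* helpers.
exchanged = {"w2": "tensor received from w2's main rank"}
filled = {name: exchanged[shard.key] for name, shard in deferred.items()}
assert set(to_load) == {"w1"} and set(filled) == {"w2"}
```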
+ + Returns: + None + """ + _fill_in_deferred_sharded_items( + sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id + ) + + @staticmethod + def fill_in_deferred_sharded_tensors( + sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor] + ) -> None: + """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to fill in. + ShardedTensors are completely replaced with corresponding torch.Tensors. + loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map + ShardedTensor from the sharded_state_dict to loaded tensors. + + Returns: + None + """ + _fill_in_deferred_sharded_items( + sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id + ) + + def apply_loading_parallelization( + self, sharded_state_dict: ShardedStateDict + ) -> Optional[ShardDistribution]: + """Distributes the load across ranks by exchanging metadata. + + Exchanges metadata from the state dict and computes the uniform + (as close as possible) distribution of loads among the ranks. + Marks ShardedTensors to be loaded by the current rank with replica_id 0 + (and others with non 0 values). + + If `self.do_cache_distribution` is True, caches the distribution between + the calls and subsequent distributions happen without any inter-rank + communication. + + Args: + sharded_state_dict (ShardedStateDict): state dict to distribute the loading + + Returns: + ShardDistribution (optional): the computed loading distribution + """ + if self.do_cache_distribution and self.cached_distribution is not None: + logger.debug(f'Apply *cached* load parallelization') + precomputed_distribution = self.cached_distribution + else: + logger.debug(f'Apply load parallelization') + precomputed_distribution = determine_main_replica_uniform_distribution( + sharded_state_dict, self.parallelization_group, True + ) + + distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict, self.parallelization_group, precomputed_distribution + ) + if self.do_cache_distribution: + self.cached_distribution = precomputed_distribution + + return precomputed_distribution + + @property + def can_handle_sharded_objects(self): + return self.base_strategy.can_handle_sharded_objects + + def load_tensors_metadata(self, checkpoint_dir: Path): + return self.base_strategy.load_tensors_metadata(checkpoint_dir) + + def load_sharded_metadata(self, checkpoint_dir: Path): + return self.base_strategy.load_sharded_metadata(checkpoint_dir) + + def check_backend_compatibility(self, loaded_version): + return self.base_strategy.check_backend_compatibility(loaded_version) + + def check_version_compatibility(self, loaded_version): + return self.base_strategy.check_version_compatibility(loaded_version) + + +def distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict: ShardedStateDict, + parallelization_group: torch.distributed.ProcessGroup, + precomputed_distribution: Optional[ShardDistribution], +): + """Applies the save distribution computed with `determine_main_replica_uniform_distribution`. + + Based on rank assignment, sets replica ids of the shards saved by current rank to 0 + and all the other replica ids to 1. + + Args: + sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to + parallelization_group (ProcessGroup): distribution will be applied within this + process group. 
Must match with the process group passed to + `determine_main_replica_uniform_distribution`. + precomputed_distribution (ShardDistribution): distribution computed with + `determine_main_replica_uniform_distribution` + + Returns: None + + Example replica ids of tensors A, B, C before distribution: + rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0) + rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1) + rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2) + + Replicas after distribution for the example above: + rank0: A: 0, B: 1, C: 1 + rank1: A: 1, B: 0, C: 1 + rank2: A: 1, B: 1, C: 0 + """ + if torch.distributed.get_world_size(group=parallelization_group) <= 1: + return + if precomputed_distribution is None: + raise ValueError( + 'precomputed_distribution must be not None for non-trivial parallelization group' + ) + + local_shards = list( + sh_base + for sh_base in nested_values(sharded_state_dict) + if isinstance(sh_base, ShardedTensor) + ) + + rank_within_dp_group = torch.distributed.get_rank(parallelization_group) + for sh_ten in local_shards: + shard_id = _sharded_tensor_shard_id(sh_ten) + if ( + shard_id in precomputed_distribution.shards_in_this_group + and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id] + ): + sh_ten.replica_id = 0 + else: + sh_ten.replica_id = 1 + + +def _defer_loading_sharded_items( + sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId] +) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]: + """Divides state dict into parts loaded by this vs other ranks. + + Args: + sharded_state_dict (ShardedStateDict): state dict with sharded items + that will be divided. + item_type: The type of sharded item (ShardedObject or ShardedTensor) + shard_id_func: Function to get the shard ID for the item type + + Returns: a tuple of: + - ShardedStateDict: sub-state dict only with sharded items + - ShardedStateDict: sub-state dict with non-sharded items + - Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank + - Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks + """ + to_load_shards = {} + unloaded_shards = {} + + sharded_items, remaining_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, item_type) + ) + + def wrap_non_main_replicas(x: Any) -> Any: + if isinstance(x, item_type): + shard_id = shard_id_func(x) + if is_main_replica(x.replica_id): + to_load_shards[shard_id] = x + else: + unloaded_shards[shard_id] = x + return x + + dict_list_map_inplace(wrap_non_main_replicas, sharded_items) + return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards + + +def _fill_in_deferred_sharded_items( + sharded_state_dict: ShardedStateDict, + loaded_items: Dict[_ShardId, Any], + item_type: type, + shard_id_func: Callable[[T], _ShardId], +) -> None: + """Helper function to fill in items not loaded by current rank.""" + + def fill_in_sharded_item(x: Any) -> Any: + if isinstance(x, item_type): + try: + x = loaded_items[shard_id_func(x)] + except KeyError as e: + raise CheckpointingException( + f'Missing loaded item shard: {shard_id_func(x)}' + ) from e + return x + + dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict) diff --git a/megatron/core/dist_checkpointing/strategies/resharding.py b/megatron/core/dist_checkpointing/strategies/resharding.py index c1c2bcec8499709ef1e072240b662543bd8c8fc1..6b5aeb2f2d98e65d1d5b662a2b883535ed5bc2b8 100644 --- 
a/megatron/core/dist_checkpointing/strategies/resharding.py +++ b/megatron/core/dist_checkpointing/strategies/resharding.py @@ -1,315 +1,318 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -""" Performant resharding of flattened tensors. - -Tensors that are first sharded (e.g. across TP) and then flattened cause -very irregular access patterns during loading. The idea for performant save/load -is to store tensors with global shape [X, Y, Z] and local shape [x, y, z] -as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and -local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the -last (flattened) dimension. During loading, some additional resharding is needed. -""" -import logging -import math -from dataclasses import dataclass -from itertools import product -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -from torch.distributed.checkpoint import ChunkStorageMetadata -from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor - -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.core import CheckpointingException -from megatron.core.dist_checkpointing.dict_utils import ( - dict_list_map_inplace, - extract_matching_values, -) -from megatron.core.dist_checkpointing.mapping import ( - ReplicaId, - ShardedStateDict, - ShardedTensorFactory, - StateDict, - apply_factories, - apply_factory_merges, -) - -logger = logging.getLogger(__name__) - - -@dataclass -class TensorReformulationMetadata: - """Metadata needed to restore the original tensor shape. - - Args: - ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor - saved in the checkpoint. This is the global shape of the application, - further reformulated into `ckpt_reform_global_shape` while saving. - ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor - saved in the checkpoint. This is the actual saved shape. - """ - - ckpt_orig_global_shape: Tuple[int, ...] - ckpt_reform_global_shape: Tuple[int, ...] - - def __post_init__(self): - assert self.ckpt_orig_global_shape - - -def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]: - """Reformulated global shape of the flattened N-D ShardedTensor. - - N-D tensor global shape [X, Y, Z] and local shape [x, y, z] - is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and - local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the - last (flattened) dimension. - - Args: - sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1) - - Returns: - Tuple[int, ...]: reformulated tensor shape - """ - assert is_nd_flattened_tensor(sh_ten), sh_ten - return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),) - - -def is_nd_flattened_tensor(sh_ten: Any) -> bool: - """Checks if ShardedTensor is flattened and more than 1-dimensional - - Args: - sh_ten (Any): any object - - Returns: - bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1) - """ - return ( - isinstance(sh_ten, ShardedTensor) - and sh_ten.flattened_range is not None - and len(sh_ten.global_shape) > 1 - ) - - -# information needed to restore. 
With current implementation, this is a nested state dict -# with ShardedTensorFactories which is basically a ShardedStateDict type -ReformulationRestoreMetadata = ShardedStateDict - - -def apply_nd_flattened_tensors_reformulation( - sharded_state_dict: ShardedStateDict, - reformulation_metadata: Dict[str, TensorReformulationMetadata], -) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]: - """Applies N-D reformulation to a given sharded state dict. - - After applying the method and loading the reformulated state dict, - the `restore_nd_flattened_tensors_formulation` needs to be applied. - - Current implementation uses ShardedTensorFactories for convenience of - restoring the original structure, but it's just an implementation detail. - Turns N-D ShardedTensors into factories and immediately applies them, - keeping the data needed to restore the original structure. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict potentially - with tensors to reformulate. - reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict - containing all metadata needed for reformulating tensors in `sharded_state_dict`. - for each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an - entry with `sh_ten.key`. - - Returns: - tuple: - ShardedStateDict - reformulated sharded state dict - ReformulationRestoreMetadata - data needed to restore the original formulation - with `restore_nd_flattened_tensors_formulation` - """ - - def maybe_reformulate_nd_flattened_tensor(sh_ten: Any): - if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten): - return sh_ten - # N-D flattened ShardedTensor - try: - sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key] - except KeyError as e: - raise CheckpointingException( - f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}' - ) from e - - ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape - app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) - if ckpt_actual_saved_shape == app_actual_load_shape: - # Same shape - no need to reshard - return sh_ten - - return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata) - - # Turn N-D tensors into factories and immediately apply them - dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict) - sh_ten_factories, _ = extract_matching_values( - sharded_state_dict, - lambda x: isinstance(x, ShardedTensorFactory), - return_lists_as_dicts=True, - ) - apply_factories(sharded_state_dict) - - # Unlink `data` pointers to free memory - def unlink_data(x): - x.data = None - return x - - dict_list_map_inplace(unlink_data, sh_ten_factories) - return sharded_state_dict, sh_ten_factories - - -def restore_nd_flattened_tensors_formulation( - state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata -) -> StateDict: - """Restores the original state dict from a reformulated form. - - Inverse of `apply_nd_flattened_tensors_reformulation`. - - Args: - state_dict (StateDict): state dict obtained by loading a reformulated - sharded state dict. 
- formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by - `apply_nd_flattened_tensors_reformulation` function - - Returns: - StateDict: state dict with the original tensors formulation restored - """ - return apply_factory_merges(state_dict, formulation_restore_metadata) - - -def reformulate_single_nd_flattened_tensor( - sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata -) -> Union[Any, ShardedTensorFactory]: - """Reformulates shapes of a single N-D flattened ShardedTensor. - - We need to define a pair of transformations: - - turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors - - merge multiple reformulated loaded torch.Tensors into a single original tensor - Current implementation uses ShardedTensorFactories as a convenient mechanism - for specifying and keeping track of those transformations. - - Args: - sh_ten (ShardedTensor): sharded tensor to reformulate. - reformulation_metadata (TensorReformulationMetadata): metadata needed to - perform the reformulation - - Returns: - ShardedTensorFactory: factory that keeps information how to reformulate - (build) the ShardedTensor and then restore original formulation (merge) - after loading. - """ - rmd = reformulation_metadata - # Data won't be needed - remove unnecessary tensor references - sh_ten = sh_ten.without_data() - - # Based on reformulation_metadata, determine other tensor shapes and metadata - ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1] - for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation): - assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape) - ckpt_local_shape_with_prepended_axis = tuple( - sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation) - ) - assert ( - ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num] - == (1,) * sh_ten.prepend_axis_num - ), (ckpt_local_shape_with_prepended_axis, sh_ten) - ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :] - - # Iterate over reformulated shapes needed by the application and from checkpoint, - # and generate new ShardedTensors that match the checkpoint sharding. 
- overlap_dim_offsets = [] - assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), ( - ckpt_axis_fragmentation, - sh_ten, - ) - for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate( - zip( - sh_ten.local_chunk_offset_in_global(), - ckpt_axis_fragmentation, - sh_ten.axis_fragmentations, - ) - ): - # without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units - first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset) - # `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units - next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1)) - overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset)) - - logger.debug( - f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}' - f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}' - ) - reformulated_sh_tens = {} - for chunk_offset in product(*overlap_dim_offsets): - global_offset = tuple( - chunk_off * chunk_shape - for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis) - ) - reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor( - sh_ten.key, - None, - sh_ten.dtype, - ckpt_local_shape, - rmd.ckpt_orig_global_shape, - global_offset, - ckpt_axis_fragmentation, - sh_ten.replica_id, - sh_ten.prepend_axis_num, - sh_ten.allow_shape_mismatch, - flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]), # whole ckpt shard - ) - - # Now, we have to define the transformations from application sharding - # to checkpoint sharding. - - @torch.no_grad() - def sh_ten_build_fn(*args, **kwargs): - # Here we simply return the precomputed tensors. - return reformulated_sh_tens - - @torch.no_grad() - def sh_ten_merge_fn(sub_state_dict): - # This is the non-flattened local tensor with original formulation - # that we are going to fill with shards loaded from the checkpoint. - app_non_flat_ten = torch.empty( - sh_ten.local_shape, - dtype=sh_ten.dtype, - device=sh_ten.data.device if sh_ten.data is not None else None, - ) - - assert len(sub_state_dict) > 0 - for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items(): - # For each ckpt shard, we fill the appropriate application shard part - dest_ten = app_non_flat_ten - src_ten = ckpt_ten.view(ckpt_local_shape) - # We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice - for ( - dim, - offset_for_saved_tensor, - offset_for_current_tensor, - length, - ) in _shards_get_overlap_region_wrt_saved_tensor( - saved_shard=ChunkStorageMetadata( - ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape - ), - current_shard=ChunkStorageMetadata( - sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape - ), - ): - src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length) - dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length) - dest_ten.copy_(src_ten) - return app_non_flat_ten.flatten()[sh_ten.flattened_range] - - return ShardedTensorFactory( - sh_ten.key, - sh_ten.data, - sh_ten_build_fn, - sh_ten_merge_fn, - sh_ten.replica_id, - sh_ten.flattened_range, - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" Performant resharding of flattened tensors. + +Tensors that are first sharded (e.g. 
across TP) and then flattened cause +very irregular access patterns during loading. The idea for performant save/load +is to store tensors with global shape [X, Y, Z] and local shape [x, y, z] +as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and +local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the +last (flattened) dimension. During loading, some additional resharding is needed. +""" +import logging +import math +from dataclasses import dataclass +from itertools import product +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +from torch.distributed.checkpoint import ChunkStorageMetadata +from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor + +from megatron.core.dist_checkpointing import ShardedTensor +from megatron.core.dist_checkpointing.core import CheckpointingException +from megatron.core.dist_checkpointing.dict_utils import ( + dict_list_map_inplace, + extract_matching_values, +) +from megatron.core.dist_checkpointing.mapping import ( + ShardedStateDict, + ShardedTensorFactory, + StateDict, + apply_factories, + apply_factory_merges, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class TensorReformulationMetadata: + """Metadata needed to restore the original tensor shape. + + Args: + ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor + saved in the checkpoint. This is the global shape of the application, + further reformulated into `ckpt_reform_global_shape` while saving. + ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor + saved in the checkpoint. This is the actual saved shape. + """ + + ckpt_orig_global_shape: Tuple[int, ...] + ckpt_reform_global_shape: Tuple[int, ...] + + def __post_init__(self): + assert self.ckpt_orig_global_shape + + +def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]: + """Reformulated global shape of the flattened N-D ShardedTensor. + + N-D tensor global shape [X, Y, Z] and local shape [x, y, z] + is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and + local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the + last (flattened) dimension. + + Args: + sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1) + + Returns: + Tuple[int, ...]: reformulated tensor shape + """ + assert is_nd_flattened_tensor(sh_ten), sh_ten + return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),) + + +def is_nd_flattened_tensor(sh_ten: Any) -> bool: + """Checks if ShardedTensor is flattened and more than 1-dimensional + + Args: + sh_ten (Any): any object + + Returns: + bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1) + """ + return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None + + +# information needed to restore. With current implementation, this is a nested state dict +# with ShardedTensorFactories which is basically a ShardedStateDict type +ReformulationRestoreMetadata = ShardedStateDict + + +def apply_nd_flattened_tensors_reformulation( + sharded_state_dict: ShardedStateDict, + reformulation_metadata: Dict[str, TensorReformulationMetadata], +) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]: + """Applies N-D reformulation to a given sharded state dict. + + After applying the method and loading the reformulated state dict, + the `restore_nd_flattened_tensors_formulation` needs to be applied. 
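A quick worked example of the reformulated global shape computed by `nd_flattened_tensor_reformulated_global_shape`, with assumed shapes:

```python
import numpy as np

global_shape = (8, 16, 32)   # [X, Y, Z]
local_shape = (4, 4, 8)      # [x, y, z], the per-rank chunk
axis_fragmentations = tuple(g // l for g, l in zip(global_shape, local_shape))

reformulated = axis_fragmentations + (int(np.prod(local_shape)),)
assert reformulated == (2, 4, 4, 128)   # [X // x, Y // y, Z // z, x * y * z]
# Each rank then owns a [1, 1, 1, 128] chunk of the reformulated tensor, so the
# save parallelizes cleanly along the last (flattened) dimension.
```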
+ + Current implementation uses ShardedTensorFactories for convenience of + restoring the original structure, but it's just an implementation detail. + Turns N-D ShardedTensors into factories and immediately applies them, + keeping the data needed to restore the original structure. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict potentially + with tensors to reformulate. + reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict + containing all metadata needed for reformulating tensors in `sharded_state_dict`. + for each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an + entry with `sh_ten.key`. + + Returns: + tuple: + ShardedStateDict - reformulated sharded state dict + ReformulationRestoreMetadata - data needed to restore the original formulation + with `restore_nd_flattened_tensors_formulation` + """ + + def maybe_reformulate_nd_flattened_tensor(sh_ten: Any): + if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten): + return sh_ten + # N-D flattened ShardedTensor + try: + sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key] + except KeyError as e: + # Handle legacy checkpointing where 1-D flatten tensor metadata was not saved + if len(sh_ten.global_shape) == 1: + return sh_ten + raise CheckpointingException( + f'Missing reformulation metadata for tensor {sh_ten}. ' + f'Existing keys: {reformulation_metadata.keys()}' + ) from e + + ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape + app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) + if ckpt_actual_saved_shape == app_actual_load_shape: + # Same shape - no need to reshard + return sh_ten + + return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata) + + # Turn N-D tensors into factories and immediately apply them + dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict) + sh_ten_factories, _ = extract_matching_values( + sharded_state_dict, + lambda x: isinstance(x, ShardedTensorFactory), + return_lists_as_dicts=True, + ) + apply_factories(sharded_state_dict) + + # Unlink `data` pointers to free memory + def unlink_data(x): + x.data = None + return x + + dict_list_map_inplace(unlink_data, sh_ten_factories) + return sharded_state_dict, sh_ten_factories + + +def restore_nd_flattened_tensors_formulation( + state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata +) -> StateDict: + """Restores the original state dict from a reformulated form. + + Inverse of `apply_nd_flattened_tensors_reformulation`. + + Args: + state_dict (StateDict): state dict obtained by loading a reformulated + sharded state dict. + formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by + `apply_nd_flattened_tensors_reformulation` function + + Returns: + StateDict: state dict with the original tensors formulation restored + """ + return apply_factory_merges(state_dict, formulation_restore_metadata) + + +def reformulate_single_nd_flattened_tensor( + sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata +) -> Union[Any, ShardedTensorFactory]: + """Reformulates shapes of a single N-D flattened ShardedTensor. 
+ + We need to define a pair of transformations: + - turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors + - merge multiple reformulated loaded torch.Tensors into a single original tensor + Current implementation uses ShardedTensorFactories as a convenient mechanism + for specifying and keeping track of those transformations. + + Args: + sh_ten (ShardedTensor): sharded tensor to reformulate. + reformulation_metadata (TensorReformulationMetadata): metadata needed to + perform the reformulation + + Returns: + ShardedTensorFactory: factory that keeps information how to reformulate + (build) the ShardedTensor and then restore original formulation (merge) + after loading. + """ + rmd = reformulation_metadata + # Data won't be needed - remove unnecessary tensor references + sh_ten = sh_ten.without_data() + + # Based on reformulation_metadata, determine other tensor shapes and metadata + ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1] + for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation): + assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape) + ckpt_local_shape_with_prepended_axis = tuple( + sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation) + ) + assert ( + ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num] + == (1,) * sh_ten.prepend_axis_num + ), (ckpt_local_shape_with_prepended_axis, sh_ten) + ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :] + + # Iterate over reformulated shapes needed by the application and from checkpoint, + # and generate new ShardedTensors that match the checkpoint sharding. + overlap_dim_offsets = [] + assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), ( + ckpt_axis_fragmentation, + sh_ten, + ) + for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate( + zip( + sh_ten.local_chunk_offset_in_global(), + ckpt_axis_fragmentation, + sh_ten.axis_fragmentations, + ) + ): + # without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units + first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset) + # `math.ceil` argument is an exact offset of the app next shard expressed + # in ckpt_local_shape units + next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1)) + overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset)) + + logger.debug( + f'Generated the following number of overlap shards for each dimension: ' + f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt ' + f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} ' + f'and chunk offset {sh_ten.local_chunk_offset_in_global()}' + ) + reformulated_sh_tens = {} + for chunk_offset in product(*overlap_dim_offsets): + global_offset = tuple( + chunk_off * chunk_shape + for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis) + ) + reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor( + sh_ten.key, + None, + sh_ten.dtype, + ckpt_local_shape, + rmd.ckpt_orig_global_shape, + global_offset, + ckpt_axis_fragmentation, + sh_ten.replica_id, + sh_ten.prepend_axis_num, + sh_ten.allow_shape_mismatch, + flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]), # whole ckpt shard + ) + + # Now, we have to define the transformations from application sharding + # to checkpoint sharding. 
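A numeric instance of the per-dimension overlap range computed above, with assumed fragmentation values (not taken from a real checkpoint):

```python
import math

# One dimension: the checkpoint splits it into 8 chunks, the application into 2,
# and we ask which checkpoint chunks overlap application chunk number 1.
ckpt_fragm, app_fragm, app_chunk_dim_offset = 8, 2, 1

first = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)              # 4
nxt = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))    # 8
assert list(range(first, nxt)) == [4, 5, 6, 7]
```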
+ + @torch.no_grad() + def sh_ten_build_fn(*args, **kwargs): + # Here we simply return the precomputed tensors. + return reformulated_sh_tens + + @torch.no_grad() + def sh_ten_merge_fn(sub_state_dict): + # This is the non-flattened local tensor with original formulation + # that we are going to fill with shards loaded from the checkpoint. + app_non_flat_ten = torch.empty( + sh_ten.local_shape, + dtype=sh_ten.dtype, + device=sh_ten.data.device if sh_ten.data is not None else None, + ) + + assert len(sub_state_dict) > 0 + for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items(): + # For each ckpt shard, we fill the appropriate application shard part + dest_ten = app_non_flat_ten + src_ten = ckpt_ten.view(ckpt_local_shape) + # We don't need narrowing over `prepend_axis_num` axes so we take + # the [sh_ten.prepend_axis_num:] offsets slice + for ( + dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=ChunkStorageMetadata( + ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape + ), + current_shard=ChunkStorageMetadata( + sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape + ), + ): + src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length) + dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length) + dest_ten.copy_(src_ten) + return app_non_flat_ten.flatten()[sh_ten.flattened_range] + + return ShardedTensorFactory( + sh_ten.key, + sh_ten.data, + sh_ten_build_fn, + sh_ten_merge_fn, + sh_ten.replica_id, + sh_ten.flattened_range, + ) diff --git a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py index 7b35209f2131413f348b5ecbd6ebcec1b2a2117a..200304a25a615a87892414c4bad9a6dc46f5ed5b 100644 --- a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +++ b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py @@ -1,162 +1,247 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -""" State dict saver for PyT Distributed format allowing asynchronous save. """ - -from logging import getLogger -from time import time -from typing import TYPE_CHECKING, Optional, Tuple, cast - -import torch -import torch.distributed as dist -from torch.distributed.checkpoint import CheckpointException -from torch.distributed.checkpoint.default_planner import DefaultSavePlanner -from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata -from torch.distributed.checkpoint.planner import SavePlan, SavePlanner -from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict - -if TYPE_CHECKING: - from .filesystem_async import FileSystemWriterAsync - - -logger = getLogger(__name__) - - -def save_state_dict_async_plan( - state_dict: STATE_DICT_TYPE, - storage_writer: 'FileSystemWriterAsync', - process_group: Optional[dist.ProcessGroup] = None, - coordinator_rank: int = 0, - planner: Optional[SavePlanner] = None, - cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None, -) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]: - """ - First stage of saving a state dict to storage. - - This is an async adjustment of torch.distributed.checkpoint.state_dict_saver. - In order to support async save, saving should be split into three parts: - 1. Planning - 2. Actual saving - 3. Finalization - - Out of these, step (2) *must* happen asynchronously. - The first step is realized with this function. 
- - The planning part consists of several steps, described here: - https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner - - Args: - state_dict (STATE_DICT_TYPE): state dict to save - storage_writer (FileSystemWriterAsync): in current version only an instance of - FileSystemWriterAsync - process_group (dist.ProcessGroup, optional): process group used for save planning - coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0. - planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format - cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], Optional): - Each object of this tuple will be used in the order as following - cached_central_plan (SavePlan): a globally coordinated save plan - cached in the previous iteration - cached_local_plan (SavePlan): a local plan - cached in the previous iteration - validated_cache_reuse (bool): boolean value to tell global_metadata and planning dict - is consistent over iterations - - Returns: Tuple of: - - storage writer (the one passed as input) - - metadata from planning - - distributed wrapper used for planning - The return value of this function should be passed as an input to - `save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning. - """ - cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False) - if cached_ckpt_structure: - cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure - - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - dist_wrapper = _DistWrapper(process_group, True, coordinator_rank) - if planner is None: - planner = DefaultSavePlanner() - assert planner is not None - - global_metadata = None - logger.debug(f"rank: {rank}, starting state dict save") - local_plan = cached_local_plan - - def local_step(): - nonlocal local_plan - assert planner is not None - # PyTorch 2.4 introduced additional `metadata` argument, - # we have to reference `is_coordinator` args by name - planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator) - storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator) - if not validated_cache_reuse and local_plan is None: - local_plan = planner.create_local_plan() - local_plan = storage_writer.prepare_local_plan(local_plan) - return local_plan - - def global_step(all_local_plans): - nonlocal global_metadata - assert planner is not None - all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) - all_local_plans = storage_writer.prepare_global_plan(all_local_plans) - return all_local_plans - - # Execute local and global planning - start_plan = time() - if validated_cache_reuse and cached_central_plan: - logger.debug(f"rank: {rank}, Passed cache reusable") - local_step() - central_plan = cached_central_plan - else: - central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step) - central_plan = planner.finish_plan(central_plan) - end_plan = time() - logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}") - # Prepare async writing of tensors. 
- # The `storage_writer` will store the information about tensors it needs to save - start = time() - storage_writer.prepare_write_data(central_plan, planner) - end = time() - logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}") - return ( - (storage_writer, cast(Metadata, global_metadata), dist_wrapper), - central_plan, - local_plan, - cached_central_plan == central_plan, - ) - - -def save_state_dict_async_finalize( - storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper -) -> None: - """ - Finalization of save_state_dict_async_plan. - - The input arguments are the same as the save_state_dict_async_plan output, - the `write_results` are retrieved from the storage_writer. - - Args: - storage_writer (FileSystemWriterAsync): storage writer used for planning - global_metadata (Metadata): metadata created during planning - dist_wrapper (_DistWrapper): distributed wrapper created during planning - - Returns: None - """ - write_results = storage_writer.retrieve_write_results() - - # Gather the write results that will be saved to the metadata file. - gather_start = time() - all_results = dist_wrapper.gather_object(write_results) - gather_end = time() - logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}") - - # Store the metadata on coordinator rank - if dist_wrapper.is_coordinator: - node_failures = _get_failure_dict(all_results) - if len(node_failures) == 0: - assert global_metadata is not None - write_start = time() - storage_writer.finish(global_metadata, all_results) - write_end = time() - logger.debug(f"{write_end}, metadata_write: {write_end - write_start}") - else: - raise CheckpointException("write", node_failures) +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" State dict saver for PyT Distributed format allowing asynchronous save. 
""" + +from logging import getLogger +from time import time +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint import CheckpointException +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict + +if TYPE_CHECKING: + from .filesystem_async import FileSystemWriterAsync + from .torch import MCoreSavePlanner + + +logger = getLogger(__name__) + +from dataclasses import fields + + +def _compare_dataclasses(obj1, obj2): + if type(obj1) != type(obj2): + return f"Objects are of different types: {type(obj1)} and {type(obj2)}" + + differences = [] + for field in fields(obj1): + value1 = getattr(obj1, field.name) + value2 = getattr(obj2, field.name) + if value1 != value2: + differences.append(f"{field.name}: {value1} != {value2}") + + return differences if differences else "All fields are equal" + + +def save_state_dict_async_plan( + state_dict: STATE_DICT_TYPE, + storage_writer: 'FileSystemWriterAsync', + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None, + cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None, + loaded_all_plans: Optional[List[SavePlan]] = None, +) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]: + """ + First stage of saving a state dict to storage. + + This is an async adjustment of torch.distributed.checkpoint.state_dict_saver. + In order to support async save, saving should be split into three parts: + 1. Planning + 2. Actual saving + 3. Finalization + + Out of these, step (2) *must* happen asynchronously. + The first step is realized with this function. + + The planning part consists of several steps, described here: + https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner + + Args: + state_dict (STATE_DICT_TYPE): state dict to save + storage_writer (FileSystemWriterAsync): in current version only an instance of + FileSystemWriterAsync + process_group (dist.ProcessGroup, optional): process group used for save planning + coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0. + planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format + cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], Optional): + Each object of this tuple will be used in the order as following + cached_central_plan (SavePlan): a globally coordinated save plan + cached in the previous iteration + cached_local_plan (SavePlan): a local plan + cached in the previous iteration + validated_cache_reuse (bool): boolean value to tell global_metadata and planning dict + is consistent over iterations + + Returns: Tuple of: + - storage writer (the one passed as input) + - metadata from planning (or None if we reuse cached global metadata) + - distributed wrapper used for planning + The return value of this function should be passed as an input to + `save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning. 
+ """ + cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False) + if cached_ckpt_structure: + cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure + + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + dist_wrapper = _DistWrapper(process_group, True, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metadata = None + logger.debug(f"rank: {rank}, starting state dict save") + local_plan = cached_local_plan + global_md_verify_reuse = False + + def local_step(): + nonlocal local_plan + assert planner is not None + # PyTorch 2.4 introduced additional `metadata` argument, + # we have to reference `is_coordinator` args by name + planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator) + storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator) + if not validated_cache_reuse and local_plan is None: + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + nonlocal global_metadata + assert planner is not None + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + # Execute local and global planning + # Ideally we want to use the cached plan. Otherwise if the planner and storage_writer + # allow it (`can_run_decentralized_global_plan`) we gather the plans to create + # the metadata but prepare the plans independently on each rank. + # In the worst case we have to reduce_scatter all the plans. + start_plan = time() + if validated_cache_reuse and cached_central_plan: + logger.debug(f"rank: {rank}, Passed cache reusable") + local_step() + central_plan = cached_central_plan + elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr( + storage_writer, 'can_run_decentralized_global_plan', False + ): + local_plan = local_step() + global_md_verify_reuse = verify_global_md_reuse( + loaded_all_plans, local_plan, rank, dist_wrapper + ) + + if not loaded_all_plans or not global_md_verify_reuse: + all_local_plans = dist_wrapper.gather_object(local_plan) + if dist_wrapper.is_coordinator: + _, global_metadata = planner.create_global_plan(all_local_plans) + global_metadata.all_local_plans = all_local_plans + else: + logger.debug(f"rank: {rank}, Passed cached global metadata") + global_metadata = None + local_plan = planner.create_decentralized_global_plan(local_plan) + local_plan = storage_writer.prepare_decentralized_global_plan(local_plan) + central_plan = local_plan + else: + central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step) + central_plan = planner.finish_plan(central_plan) + end_plan = time() + logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}") + # Prepare async writing of tensors. 
+ # The `storage_writer` will store the information about tensors it needs to save + start = time() + storage_writer.prepare_write_data(central_plan, planner) + end = time() + logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}") + return ( + (storage_writer, global_metadata, dist_wrapper), + central_plan, + local_plan, + cached_central_plan == central_plan, + global_md_verify_reuse, + ) + + +def verify_global_md_reuse( + loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper +) -> bool: + """ + Verifies that global metadata reuse is possible by checking the loaded plans from the + checkpoint are consistent, which means we have the same settings when resuming training. + Args: + loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint + (stored in checkpoint metadata). + local_plan: SavePlan, The local save plan. + rank: Current process rank. + dist_wrapper (_DistWrapper): distributed wrapper created during planning + + Returns: True iff the global metadata reuse is possible. + + """ + logger.debug(f"verifying reuse of global metadata") + if not loaded_all_plans: + global_md_verify_reuse = False + logger.debug("loaded global metadata reuse verification: no loaded plans passed") + + elif len(loaded_all_plans) == dist_wrapper.get_world_size(): + local_verify_reuse = all( + getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name) + for f in fields(local_plan) + if f.name != 'storage_data' + ) + + if not local_verify_reuse: + logger.debug( + f"local_verify_reuse is False: diffs -" + f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}" + ) + all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda') + torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN) + # Check if all reduced results are True + global_md_verify_reuse = all_results.item() == 1 + else: + global_md_verify_reuse = False + return global_md_verify_reuse + + +def save_state_dict_async_finalize( + storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper +) -> None: + """ + Finalization of save_state_dict_async_plan. + + The input arguments are the same as the save_state_dict_async_plan output, + the `write_results` are retrieved from the storage_writer. + + Args: + storage_writer (FileSystemWriterAsync): storage writer used for planning + global_metadata (Metadata): metadata created during planning + dist_wrapper (_DistWrapper): distributed wrapper created during planning + + Returns: None + """ + write_results = storage_writer.retrieve_write_results() + + # Gather the write results that will be saved to the metadata file. 
+ gather_start = time() + all_results = dist_wrapper.gather_object(write_results) + gather_end = time() + logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}") + + # Store the metadata on coordinator rank + if dist_wrapper.is_coordinator: + node_failures = _get_failure_dict(all_results) + if len(node_failures) == 0: + assert global_metadata is not None + write_start = time() + storage_writer.finish(global_metadata, all_results) + write_end = time() + logger.debug(f"{write_end}, metadata_write: {write_end - write_start}") + else: + raise CheckpointException("write", node_failures) diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index ea95254ad1ad1f5c73448e55d1d51f3983d31ff8..ec647ebd1cc7df6ef266f48d5a28d6e7bdea6721 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -1,939 +1,1010 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Strategies using PyTorch distributed.checkpoint as an underlying format. """ -import io -import os -import pickle -import warnings -from collections import ChainMap, defaultdict -from dataclasses import dataclass -from itertools import product -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast - -import torch -from packaging.version import Version as PkgVersion -from torch.distributed import checkpoint -from torch.distributed._shard.metadata import ShardMetadata -from torch.distributed._shard.sharded_tensor import Shard -from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor -from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties -from torch.distributed.checkpoint import ( - BytesStorageMetadata, - DefaultLoadPlanner, - DefaultSavePlanner, - FileSystemReader, - FileSystemWriter, - LoadPlan, - Metadata, - ReadItem, - SavePlan, - TensorStorageMetadata, - WriteItem, -) -from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict -from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict -from torch.distributed.checkpoint.metadata import Metadata -from torch.distributed.checkpoint.planner_helpers import _create_write_items - -from ...utils import get_torch_version, is_torch_min_version -from ..core import CheckpointingException -from ..dict_utils import nested_values -from ..mapping import ( - ShardedBase, - ShardedObject, - ShardedStateDict, - ShardedTensor, - StateDict, - is_main_replica, -) -from .async_utils import AsyncRequest -from .base import ( - AsyncSaveShardedStrategy, - LoadShardedStrategy, - StrategyAction, - register_default_strategy, -) -from .filesystem_async import FileSystemWriterAsync -from .resharding import ( - TensorReformulationMetadata, - apply_nd_flattened_tensors_reformulation, - is_nd_flattened_tensor, - nd_flattened_tensor_reformulated_global_shape, - restore_nd_flattened_tensors_formulation, -) -from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan - -try: - if not torch.cuda.is_available(): - raise ImportError - from transformer_engine.pytorch.float8_tensor import Float8Tensor - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -try: - from torch.distributed._tensor import DTensor - - HAVE_DTENSOR = True -except ImportError: - HAVE_DTENSOR = False - -_metadata_fn: str = 
".metadata" - - -def register_default_torch_strategies(): - """Register default strategies related to PyT Distributed backend.""" - register_default_strategy( - StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy() - ) - register_default_strategy( - StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1) - ) - - -logger = getLogger(__name__) - - -def flatten_state_dict( - state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]: - """Flattens state dict into a single level dict. - - It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict - which also accepts ShardedBase tensors as terminal objects - - Args: - state_dict (ShardedStateDict): state dict to be flattened - - Returns (tuple): flattened state dict and a mapping allowing to recreate the original one - - """ - flattened = {} - mappings = {} - - def flat_copy(path: OBJ_PATH, value: Any) -> None: - new_fqn = ".".join(map(str, path)) - if new_fqn in flattened: - raise ValueError(f"duplicated flatten key {new_fqn}") - flattened[new_fqn] = value - mappings[new_fqn] = path - - traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase))) - return flattened, mappings - - -def sharded_tensor_to_torch_sharded_tensor( - sh_tens: List[ShardedTensor], rank: Optional[int] = None -) -> TorchShardedTensor: - """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks. - - On high-level, this function follows the logic of - torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. - Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) - as attributes for further restoration in `_unwrap_pyt_sharded_tensor`. - - NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor. - The only local irregularities could be introduced with a `flattened_range` attribute. - - This function handles 3 different type of ShardedTensors: - 1. Non-flat regular ShardedTensors (`not has_flattened_range`) - 2. 1D flattened ShardedTensors (`is_flattened_range_1d`) - 3. N-D flattened ShardedTensors (`has_flattened_range`) - - (1) and (2) type are saved according to their original shape. - Type (3) however requires global shape adjustment for efficiency: - we treat [X, Y, Z] global shape tensor with local shape [x, y, z] - as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis - partitioned according to `flattened_range` slices. - This will need special handling while resharding. - - Args: - sh_tens (List[ShardedTensor]): list of sharded tensors to convert - rank (int, optional): current process rank passed to PyT ShardedTensor. - If None, assumes rank in the default pg. - - Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards. 
- - """ - if rank is None: - rank = torch.distributed.get_rank() - - some_sh_ten = sh_tens[0] - has_flattened_range = some_sh_ten.flattened_range is not None - is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1 - - for sh_ten in sh_tens: - assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens - if not sh_ten.data.is_contiguous(): - sh_ten.data = sh_ten.data.contiguous() - - local_global_offsets = {} - - prepend_axis_num = sh_tens[0].prepend_axis_num - # Determine local shards according to tensor type (see docs) - if is_flattened_range_1d: - # Type (2) case: 1D flattened ShardedTensors - for sh_ten in sh_tens: - assert len(sh_ten.global_offset) == 1, sh_ten - assert sh_ten.prepend_axis_num == 0, sh_ten - local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten) - - global_shape = some_sh_ten.global_shape - offsets_shape = ( - some_sh_ten.local_shape - ) # local shape is not flattened, we need it for chunk offsets - - local_shards = [ - Shard.from_tensor_and_offsets( - sh_ten.data, - [ - sh_ten.global_offset[0] + sh_ten.flattened_range.start - ], # additional flattened offset - rank, - ) - for sh_ten in sh_tens - ] - - elif has_flattened_range: - # Type (3) case: N-D flattened ShardedTensors - for sh_ten in sh_tens: - local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append( - sh_ten - ) - assert sh_ten.data.ndim == 1, sh_ten - sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,)) - - # Global shape reformulation: - global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten) - offsets_shape = (1,) * len( - some_sh_ten.global_shape - ) # reformulated global shape has shape equal ti number of local chunks - - local_shards = [ - Shard.from_tensor_and_offsets( - sh_ten.data, - list( - sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,) - ), # additional flattened offset - rank, - ) - for sh_ten in sh_tens - ] - else: - # Type (1) case: non-flat regular ShardedTensors - for sh_ten in sh_tens: - local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten) - sh_ten.data = sh_ten.data.view( - (1,) * prepend_axis_num + sh_ten.local_shape - ) # adjust to prepended_axis_num - - global_shape = some_sh_ten.global_shape - offsets_shape = some_sh_ten.data.shape # includes prepended axes - - local_shards = [ - Shard.from_tensor_and_offsets( - sh_ten.data, list(sh_ten.global_offset), rank # simple case - ) - for sh_ten in sh_tens - ] - - # Create a ShardedTensor without invoking communication. 
Determine global shards - world_size = torch.distributed.get_world_size() - shard_metadata = [] - # NOTE: here we assume a regular grid of shards - for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)): - offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape))) - if offset in local_global_offsets: - # local shard - placement = f"rank:{rank}/cuda" - for sh_ten in local_global_offsets[offset]: - if is_flattened_range_1d: - offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,) - size = sh_ten.data.shape - elif has_flattened_range: - assert offset == sh_ten.local_chunk_offset_in_global() - # This is not an actual offset, but an offset of the whole shard - # This is needed for a PyT Dist internal integrity check - offset = sh_ten.local_chunk_offset_in_global() + (0,) - size = (1,) * len(offsets_shape) + global_shape[-1:] - else: - size = sh_ten.data.shape - shard_metadata.append(ShardMetadata(offset, size, placement)) - - else: - # pylint: disable=line-too-long - # for shards from other ranks we provide simplistic data - this information will be discarded - # during TorchShardedTensor._init_from_local_shards_and_global_metadata call. - # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size. - # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS. - placement = f"rank:{(rank + 1) % world_size}/cuda" - if has_flattened_range and not is_flattened_range_1d: - offset = offset + (0,) - size = (1,) * len(offsets_shape) + global_shape[-1:] - else: - size = offsets_shape - shard_metadata.append(ShardMetadata(offset, size, placement)) - - tensor = some_sh_ten.data - sharded_tensor_metadata = ShardedTensorMetadata( - shards_metadata=shard_metadata, - size=torch.Size(global_shape), - tensor_properties=TensorProperties( - dtype=tensor.dtype, - layout=tensor.layout, - requires_grad=tensor.requires_grad, - memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned(), - ), - ) - pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata( - local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None - ) - # Store MCore related data as PyTShardedTensor attribute. - # This won't be stored in the checkpoint, only for runtime purposes - pyt_sh_ten.mcore_sh_ten = sh_ten.without_data() - pyt_sh_ten.mcore_metadata = {} - if has_flattened_range and not is_flattened_range_1d: - pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape - return pyt_sh_ten - - -def mcore_to_pyt_state_dict( - state_dict: Dict[str, List[ShardedBase]], - is_loading: bool = False, - init_device: torch.device = torch.device("cpu"), -) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]: - """Convert state dict with ShardedTensors and ShardedObjects - to state dict compatible with PyT Dist format. - - Operates in-place and returns the original state dict. - - Args: - state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values - are lists of either ShardedTensor or ShardedObjects. - is_loading (bool, optional): flag indicating if loading or saving. Defaults to False. - init_device (torch.device, optional): device to initialize potentially missing tensors - during loading. Defaults to 'cpu'. - - Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values - converted either into PyT ShardedTensors or io.BytesIO. 
- - """ - rank = torch.distributed.get_rank() - pyt_state_dict = {} - - def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor: - """Build a PyT ShardedTensor from given shards. - - During loading: - - if data is None, initialize it with an empty tensor (will be used to copy the data into) - - if `allow_shape_mismatch` is True, the data is initialized with zeros - prior to loading (not all parts of the tensor will be read from the checkpoint) - """ - assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens - for sh_ten in sh_tens: - if sh_ten.data is None: - if is_loading: - sh_ten.init_data( - init_device, - init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty, - ) - else: - raise CheckpointingException(f'`data` attr is None for {sh_ten}') - else: - sh_ten.data = sh_ten.data.detach() - if sh_ten.allow_shape_mismatch and is_loading: - sh_ten.data.zero_() - - torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank) - torch_sh_ten.key = sh_tens[0].key - return torch_sh_ten - - def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO: - """Build io.BytesIO from given sharded objects data.""" - assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs - serialized_data = io.BytesIO() - torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data) - return serialized_data - - for k, v in state_dict.items(): - if isinstance(v[0], ShardedTensor): - v = cast(List[ShardedTensor], v) - pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v) - else: - v = cast(List[ShardedObject], v) - pyt_state_dict[k] = _mcore_to_torch_sharded_object(v) - - return pyt_state_dict - - -def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]: - """Unwrap tensor from PyT ShardedTensor instance. - - If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor) - then the tensor has additional singleton dimensions which should be squeezed. 
- """ - mcore_sh_ten = sh_ten.mcore_sh_ten - ret_tensors = [] - for sh in sh_ten.local_shards(): - ten = sh.tensor - if mcore_sh_ten.flattened_range is not None: - assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape - ten = ten.view(-1) - else: - for _ in range(mcore_sh_ten.prepend_axis_num): - ten = ten.squeeze(0) - ret_tensors.append(ten) - return ret_tensors - - -def _replace_state_dict_keys_with_sharded_keys( - sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False -) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]: - """Group ShardedBase objects by keys and - return mappings required for recreating the original dict.""" - flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict) - rename_mapping = defaultdict(list) - new_flat_sd = defaultdict(list) - for k, sh_base in flat_sd.items(): - assert isinstance(sh_base, ShardedBase), type(sh_base) - key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key - if is_main_replica(sh_base.replica_id) or not keep_only_main_replica: - rename_mapping[key].append(k) - new_flat_sd[key].append(sh_base) - return new_flat_sd, flat_mapping, rename_mapping - - -def _replace_sharded_keys_with_state_dict_keys( - state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]], - flat_mapping: FLATTEN_MAPPING, - rename_mapping: Dict[str, List[str]], -): - """Inverse of _replace_state_dict_keys_with_sharded_keys.""" - recovered_sd = {} - for k, tensors in state_dict.items(): - assert len(tensors) == len(rename_mapping[k]) - for ten, recovered_k in zip(tensors, rename_mapping[k]): - recovered_sd[recovered_k] = ten - - return unflatten_state_dict(recovered_sd, flat_mapping) - - -def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]): - """Recursively update `x` keys, based on `keys_template`.""" - if isinstance(keys_template, dict): - assert isinstance(x, dict), type(x) - for k, v in keys_template.items(): - if not isinstance(k, str): - assert str(k) in x, (k, x.keys) - x[k] = x.pop(str(k)) - _restore_dict_types(x[k], v) - elif isinstance(keys_template, list): - assert isinstance(x, list), type(x) - for x_val, templ_val in zip(x, keys_template): - _restore_dict_types(x_val, templ_val) - - -@dataclass(frozen=True) -class MCoreSavePlan(SavePlan): - """SavePlan with MCore specific data.""" - - mcore_data: Dict[str, Dict[str, Any]] = None # Mcore related data about each tensor - - -class MCoreSavePlanner(DefaultSavePlanner): - """Differs with the default planner by saving BytesIO objects on all ranks. - - In the integration of MCore with PyT Distributed format, BytesIO objects - come from ShardedObjects, which should be treated as separate objects on each rank - (not common on all ranks). - - Also, the objects are already packed in io.BytesIO, so no need to redo it - in transform_object. - """ - - def __init__( - self, - *args, - dedup_replicated_tensors: Optional[bool] = None, - nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None, - **kwargs, - ) -> None: - # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings - # during saving. 
- if get_torch_version() <= PkgVersion("2.2"): - kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors - super().__init__(*args, **kwargs) - self.nd_flattened_global_shapes = nd_flattened_global_shapes or {} - - def create_local_plan(self) -> SavePlan: - """Adds IOBytes write request on non-coordinator ranks.""" - - # NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because - # some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container) - # add iobytes request only on coordinator ranks and some alpha versions - # (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container) - # add those requests on all ranks. We inline a simplified version of this method below. - write_items = [] - for fqn, obj in self.state_dict.items(): - assert not HAVE_DTENSOR or not isinstance( - obj, DTensor - ) # translation from MCore ShardedTensors shouldn't result in DTensors - # Create write requests for tensor and bytes values. - # For MCore, these should be already non-duplicates. - write_items += _create_write_items(fqn, obj) - - self.plan = MCoreSavePlan( - items=write_items, - planner_data=self.mappings, - mcore_data={ - k: sh_ten.mcore_metadata - for k, sh_ten in self.state_dict.items() - if isinstance(sh_ten, TorchShardedTensor) - }, - ) - return self.plan - - def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: - """Merges MCore data for all plans.""" - global_plan, metadata = super().create_global_plan(all_plans) - metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans))) - return global_plan, metadata - - def transform_object(self, write_item: WriteItem, object: Any): - """Make no transformations - bytes objects are already serialized.""" - return object - - -class MCoreLoadPlanner(DefaultLoadPlanner): - """Adds global shape validation to the default planner. - - If global shape validation can be ignored (shouldn't!), the default - load planner can be used. - """ - - def __init__( - self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs - ) -> None: - super().__init__(*args, **kwargs) - self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors - self._intermediate_read_item_and_target: Optional[Tuple[ReadItem, torch.Tensor]] = None - - def _validate_global_shapes(self, metadata, sharded_tensors): - for sh_ten in sharded_tensors: - if sh_ten.key not in metadata.state_dict_metadata: - raise KeyError( - f"{sh_ten.key} from model not in state dict:" - f" {sorted(metadata.state_dict_metadata.keys())}" - ) - loaded_shape = metadata.state_dict_metadata[sh_ten.key].size - if not is_nd_flattened_tensor(sh_ten): - expected_shape = sh_ten.global_shape - else: - expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) - if loaded_shape != expected_shape: - _msg = ( - f'Global shape mismatch for loaded ({loaded_shape})' - f' and expected ({expected_shape}) tensor' - f' for key {sh_ten.key}' - ) - raise CheckpointingException(_msg) - - def create_local_plan(self) -> LoadPlan: - """Runs additional shapes validation.""" - self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors) - return super().create_local_plan() - - def resolve_tensor(self, read_item: ReadItem): - """Override to add FP8 support. - - Narrowing the Float8Tensor can create incontiguous tensors and there are - no `copy` kernels for such cases. 
This method creates a contiguous FP8 - tensors so that the subsequent `copy_` in FileSystemReader succeeds. - Note that this requires tracking the original tensor - (as `self._intermediate_read_item_and_target` attribute) - and restoring it in `commit_tensor` method. - """ - target_tensor = super().resolve_tensor(read_item) - if ( - not target_tensor.is_contiguous() - and HAVE_TE - and isinstance(target_tensor, Float8Tensor) - ): - self._intermediate_read_item_and_target = (read_item, target_tensor) - target_tensor = Float8Tensor.make_like( - target_tensor, data=target_tensor._data.contiguous() - ) - return target_tensor - - def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: - """Restores the original FP8 tensor saved in `resolve_tensor`.""" - if self._intermediate_read_item_and_target is not None: - interm_read_item, target_tensor = self._intermediate_read_item_and_target - assert ( - interm_read_item is read_item - ), '`commit_tensor` method should be called right after `resolve_tensor`' - target_tensor.copy_(tensor) - tensor = target_tensor - self._intermediate_read_item_and_target = None - return super().commit_tensor(read_item, tensor) - - -class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): - """Async save strategy for the PyT Distributed format. - - The idea is to translate MCore ShardedTensors into PyT ShardedTensors - and use the async-adjusted torch.distributed.checkpoint saving mechanism - provided by the FileSystemWriterAsync writer. - """ - - def __init__( - self, - backend: str, - version: int, - keep_only_main_replica: bool = True, - thread_count: int = 2, - cached_metadata: bool = False, - separation_hint: str = None, - ): - """Adds parameters specific to PyT Distributed format - Args: - backend (str): format backend string - version (int): format version - keep_only_main_replica (bool, optional): PyT Distributed has a mechanism - for deduplication, but replica_id aware deduplication is more coherent. - Default is True (recommended to keep it). - thread_count (int, optional): threads to use during saving. - Affects the number of files in the checkpoint (saving ranks * num_threads). - cached_metadata (bool, optional): Enables using cached global metadata to avoid - gathering local metadata every checkpointing invocation - separation_hint(str, optional): If provided, all tensors whose keys have this - prefix will be saved to a separate file. 
- """ - super().__init__(backend, version) - self.keep_only_main_replica = keep_only_main_replica - self.thread_count = thread_count - - # Cached SavePlans to skip plan in `save_state_dict_async_plan` - # cached outcome of `SavePlan.prepare_global_plan`, - # which aggregates local plans from all ranks - self.cached_central_plan: SavePlan = None - # cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written - self.cached_local_plan: SavePlan = None - # Cached global metadata, only `coordinator` for dist-ckpt holds - # if central plans are consistent over iters - self.cached_global_metadata: Metadata = None - # This variable records if the ckpt structures are consistent - # so the following checkpoint savings reuse `cached_global_metadata` - self.validated_cache_reuse: bool = False - # The knob to enable cached metadata communication in saving - self.use_cached_ckpt_structure: bool = cached_metadata - - self.separation_hint = separation_hint - - def async_save( - self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path - ) -> AsyncRequest: - """Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to save - checkpoint_dir (Path): checkpoint directory - - Returns: None - """ - # Translate the state dict - (sharded_state_dict, flat_mapping, rename_mapping) = ( - _replace_state_dict_keys_with_sharded_keys( - sharded_state_dict, self.keep_only_main_replica - ) - ) - pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) - # Use PyT saving mechanism - writer = FileSystemWriterAsync( - checkpoint_dir, separation_hint=self.separation_hint, thread_count=self.thread_count - ) - # This should be set differently if we run in a smaller process group than the default - coordinator = 0 - # Try twice to validate the generated `central_plan` is the same across iterations - # If so, reuse `cached_central_plan` and `cached_global_metadata` - # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` - # (return None) so `self.cached_global_metadata` is reused - args_cached_plans = None - if self.use_cached_ckpt_structure: - args_cached_plans = ( - self.cached_central_plan, - self.cached_local_plan, - self.validated_cache_reuse, - ) - - ( - save_state_dict_ret, - self.cached_central_plan, - self.cached_local_plan, - self.validated_cache_reuse, - ) = save_state_dict_async_plan( - pyt_state_dict, - writer, - None, - coordinator, - planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica), - cached_ckpt_structure=args_cached_plans, - ) - rank = torch.distributed.get_rank() - if self.use_cached_ckpt_structure: - if self.validated_cache_reuse: - logger.debug(f"rank: {rank}, cache validated") - if save_state_dict_ret[1]: # when global_metadata is not cached - self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata - # Only Coordinator rank holds cached global_metadata - # (None is returned for global_metadata) - elif coordinator == rank: - logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}") - save_state_dict_ret = list(save_state_dict_ret) - save_state_dict_ret[1] = self.cached_global_metadata - - return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret) - - def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest: - save_fn_args = writer.get_save_function_and_args() - save_fn, save_args = save_fn_args - - def finalize_fn(): - 
save_state_dict_async_finalize(*save_state_dict_ret) - torch.distributed.barrier() - - return AsyncRequest(save_fn, save_args, [finalize_fn]) - - def can_handle_sharded_objects(self): - return True - - -def get_reformulation_metadata( - sharded_state_dict: ShardedStateDict, checkpoint_dir: Path -) -> Dict[str, TensorReformulationMetadata]: - """Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to load - checkpoint_dir (Path): checkpoint directory - - Returns: - Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every - N-D flattened tensor from the sharded_state_dict to its original global shape - as stored in `mcore_data` in the checkpoint. - """ - ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata() - reformulation_metadata = {} - for sh_ten in nested_values(sharded_state_dict): - if not is_nd_flattened_tensor(sh_ten): - continue - try: - ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][ - 'nd_reformulated_orig_global_shape' - ] - except KeyError as e: - raise CheckpointingException( - f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} ' - f'in checkpoint metadata: {ckpt_metadata.mcore_data}' - ) from e - - reformulation_metadata[sh_ten.key] = TensorReformulationMetadata( - ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size - ) - return reformulation_metadata - - -class TorchDistLoadShardedStrategy(LoadShardedStrategy): - """Basic load strategy for the PyT Distributed format.""" - - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: - """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict with mapping - information to instruct loading - checkpoint_dir (Path): checkpoint directory - - Returns: loaded state dict - """ - # Apply N-D tensors resharding - sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation( - sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir) - ) - - flexible_shape_sharded_tensors = [ - sh_ten - for sh_ten in nested_values(sharded_state_dict) - if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch - ] - - orig_sharded_state_dict = sharded_state_dict - # MCore state dict to PyT Distributed compatible - (sharded_state_dict, flat_mapping, rename_mapping) = ( - _replace_state_dict_keys_with_sharded_keys(sharded_state_dict) - ) - pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True) - # Load PyT Distributed format - checkpoint.load_state_dict( - pyt_state_dict, - FileSystemReader(checkpoint_dir), - planner=MCoreLoadPlanner( - shapes_validation_sharded_tensors=flexible_shape_sharded_tensors - ), - ) - pyt_state_dict = cast( - Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict - ) - # Unwrap ShardedTensors and return to original state dict - mcore_state_dict = { - k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v) - for k, v in pyt_state_dict.items() - } - mcore_state_dict = _replace_sharded_keys_with_state_dict_keys( - mcore_state_dict, flat_mapping, rename_mapping - ) - _restore_dict_types(mcore_state_dict, orig_sharded_state_dict) - # Apply N-D tensors resharding postprocessing - mcore_state_dict = restore_nd_flattened_tensors_formulation( - mcore_state_dict, formulation_restore_data - ) - return 
mcore_state_dict - - def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None): - """Uses tensors metadata stored in the metadata file.""" - if metadata is None: - fs_reader = FileSystemReader(checkpoint_dir) - metadata = fs_reader.read_metadata() - - mcore_data = getattr(metadata, 'mcore_data', {}) - sharded_metadata = {} - for k, tp in metadata.state_dict_metadata.items(): - if not isinstance(tp, TensorStorageMetadata): - continue # load only tensors - - nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape') - if nd_orig_global_shape is None: - # Regular tensor - sharded_metadata[k] = ShardedTensor.from_rank_offsets( - k, torch.empty(tp.size, **tp.properties.__dict__, device='meta') - ).without_data() - else: - # N-D flattened tensor - unflat_ten = torch.empty( - nd_orig_global_shape, **tp.properties.__dict__, device='meta' - ) - flat_ten = unflat_ten.flatten() - sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat( - k, - flat_ten, - unflat_ten.shape, - flattened_range=slice(0, unflat_ten.numel()), # whole slice - ).without_data() - - return sharded_metadata - - def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: - """Uses tensors and objects metadata stored in the metadata file.""" - fs_reader = FileSystemReader(checkpoint_dir) - metadata = fs_reader.read_metadata() - - sharded_metadata = {} - for metadata_key, storage_metadata in metadata.state_dict_metadata.items(): - if not isinstance(storage_metadata, BytesStorageMetadata): - continue - sh_obj = ShardedObject.empty_from_unique_key(metadata_key) - sharded_metadata[sh_obj.unique_key] = sh_obj - - sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata)) - return sharded_metadata - - def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): - """Removes checkpoint files whose keys have the given prefix. - - Performs the following steps: - 1. checks whether there are files that start with the key_prefix - 2. loads metadata - 3. removes all entries from the metadata that start with the key_prefix - 4. resaves the new metadata and removes the old metadata - 5. removes the relevant files - """ - - assert is_torch_min_version( - "2.3.0" - ), f'torch >= 2.3.0 is required for remove_sharded_tensors' - - distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")] - files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)] - - if not files_to_remove: - warnings.warn( - f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".' - f' Skipping removal.' 
- ) - return - - fs_reader = FileSystemReader(checkpoint_dir) - original_metadata = fs_reader.read_metadata() - - new_state_dict_metadata = {} - new_planner_data = {} - new_storage_data = {} - for k in original_metadata.state_dict_metadata.keys(): - if k.startswith(key_prefix): - continue - new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k] - for k in original_metadata.planner_data.keys(): - if k.startswith(key_prefix): - continue - new_planner_data[k] = original_metadata.planner_data[k] - for k in original_metadata.storage_data.keys(): - if k.fqn.startswith(key_prefix): - continue - new_storage_data[k] = original_metadata.storage_data[k] - metadata = Metadata( - state_dict_metadata=new_state_dict_metadata, - planner_data=new_planner_data, - storage_data=new_storage_data, - ) - fs_writer = FileSystemWriter(checkpoint_dir) - metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn)) - tmp_path = cast( - metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp") - ) - old_path = cast( - metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck") - ) - ## save the new metadata - with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file: - pickle.dump(metadata, metadata_file) - try: - os.fsync(metadata_file.fileno()) - except AttributeError: - os.sync() - ## move the old metadata - fs_writer.fs.rename(fs_writer.metadata_path, old_path) - try: - ## rename the new metadata - fs_writer.fs.rename(tmp_path, fs_writer.metadata_path) - - ## finally, remove the files we want to drop - for f in files_to_remove: - fs_writer.fs.rm_file(checkpoint_dir / f) - except Exception as e: - fs_writer.fs.rename(old_path, fs_writer.metadata_path) - raise e - else: - fs_writer.fs.rm_file(old_path) - - def can_handle_sharded_objects(self): - return True - - def check_backend_compatibility(self, loaded_version): - pass # TODO - - def check_version_compatibility(self, loaded_version): - pass # TODO +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Strategies using PyTorch distributed.checkpoint as an underlying format. 
""" +import io +import os +import pickle +import warnings +from collections import ChainMap, defaultdict +from dataclasses import dataclass +from itertools import product +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast + +import torch +from packaging.version import Version as PkgVersion +from torch.distributed import checkpoint +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor import Shard +from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor +from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties +from torch.distributed.checkpoint import ( + BytesStorageMetadata, + DefaultLoadPlanner, + DefaultSavePlanner, + FileSystemReader, + FileSystemWriter, + LoadPlan, + Metadata, + ReadItem, + SavePlan, + TensorStorageMetadata, + WriteItem, +) +from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict +from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner_helpers import _create_write_items + +from ...utils import get_torch_version, is_torch_min_version +from ..core import CheckpointingException +from ..dict_utils import nested_values +from ..mapping import ( + ShardedBase, + ShardedObject, + ShardedStateDict, + ShardedTensor, + StateDict, + is_main_replica, +) +from .async_utils import AsyncRequest +from .base import ( + AsyncSaveShardedStrategy, + LoadShardedStrategy, + StrategyAction, + register_default_strategy, +) +from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader +from .filesystem_async import FileSystemWriterAsync +from .resharding import ( + TensorReformulationMetadata, + apply_nd_flattened_tensors_reformulation, + is_nd_flattened_tensor, + nd_flattened_tensor_reformulated_global_shape, + restore_nd_flattened_tensors_formulation, +) +from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan + +try: + if not torch.cuda.is_available(): + raise ImportError + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + from torch.distributed._tensor import DTensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + +_metadata_fn: str = ".metadata" + + +def register_default_torch_strategies(): + """Register default strategies related to PyT Distributed backend.""" + register_default_strategy( + StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy() + ) + register_default_strategy( + StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1) + ) + + +logger = getLogger(__name__) + + +def flatten_state_dict( + state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]: + """Flattens state dict into a single level dict. 
+ + It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict + which also accepts ShardedBase tensors as terminal objects + + Args: + state_dict (ShardedStateDict): state dict to be flattened + + Returns (tuple): flattened state dict and a mapping allowing to recreate the original one + + """ + flattened = {} + mappings = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase))) + return flattened, mappings + + +def sharded_tensor_to_torch_sharded_tensor( + sh_tens: List[ShardedTensor], + rank: Optional[int] = None, + load_legacy_1d_flatten_tensors: bool = False, +) -> TorchShardedTensor: + """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks. + + On high-level, this function follows the logic of + torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. + Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) + as attributes for further restoration in `_unwrap_pyt_sharded_tensor`. + + NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor. + The only local irregularities could be introduced with a `flattened_range` attribute. + + This function handles 2 different type of ShardedTensors: + 1. Non-flat regular ShardedTensors (`not has_flattened_range`) + 2. N-D flattened ShardedTensors (`has_flattened_range`) + + (1) type are saved according to their original shape. + Type (2) however requires global shape adjustment for efficiency: + we treat [X, Y, Z] global shape tensor with local shape [x, y, z] + as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis + partitioned according to `flattened_range` slices. + This will need special handling while resharding. + + Args: + sh_tens (List[ShardedTensor]): list of sharded tensors to convert + rank (int, optional): current process rank passed to PyT ShardedTensor. + If None, assumes rank in the default pg. + load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors + should be loaded in a legacy way. Defaults to False. + + Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards. 
+ + """ + if rank is None: + rank = torch.distributed.get_rank() + + some_sh_ten = sh_tens[0] + has_flattened_range = some_sh_ten.flattened_range is not None + + for sh_ten in sh_tens: + assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens + if not sh_ten.data.is_contiguous(): + sh_ten.data = sh_ten.data.contiguous() + + if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1: + # Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors + has_flattened_range = False + + local_global_offsets = {} + + prepend_axis_num = sh_tens[0].prepend_axis_num + # Determine local shards according to tensor type (see docs) + if has_flattened_range: + # Type (3) case: N-D flattened ShardedTensors + for sh_ten in sh_tens: + local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append( + sh_ten + ) + assert sh_ten.data.ndim == 1, sh_ten + sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,)) + + # Global shape reformulation: + global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten) + offsets_shape = (1,) * len( + some_sh_ten.global_shape + ) # reformulated global shape has shape equal ti number of local chunks + + local_shards = [ + Shard.from_tensor_and_offsets( + sh_ten.data, + list( + sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,) + ), # additional flattened offset + rank, + ) + for sh_ten in sh_tens + ] + else: + # Type (1) case: non-flat regular ShardedTensors + for sh_ten in sh_tens: + local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten) + sh_ten.data = sh_ten.data.view( + (1,) * prepend_axis_num + sh_ten.local_shape + ) # adjust to prepended_axis_num + + global_shape = some_sh_ten.global_shape + offsets_shape = some_sh_ten.data.shape # includes prepended axes + + local_shards = [ + Shard.from_tensor_and_offsets( + sh_ten.data, list(sh_ten.global_offset), rank # simple case + ) + for sh_ten in sh_tens + ] + + # Create a ShardedTensor without invoking communication. Determine global shards + world_size = torch.distributed.get_world_size() + shard_metadata = [] + # NOTE: here we assume a regular grid of shards + for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)): + offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape))) + if offset in local_global_offsets: + # local shard + placement = f"rank:{rank}/cuda" + for sh_ten in local_global_offsets[offset]: + if has_flattened_range: + assert offset == sh_ten.local_chunk_offset_in_global() + # This is not an actual offset, but an offset of the whole shard + # This is needed for a PyT Dist internal integrity check + offset = sh_ten.local_chunk_offset_in_global() + (0,) + size = (1,) * len(offsets_shape) + global_shape[-1:] + else: + size = sh_ten.data.shape + shard_metadata.append(ShardMetadata(offset, size, placement)) + + else: + # pylint: disable=line-too-long + # for shards from other ranks we provide simplistic data - this information will be discarded + # during TorchShardedTensor._init_from_local_shards_and_global_metadata call. + # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size. + # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS. 
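For N-D flattened ShardedTensors, the conversion above relies on the global-shape reformulation described in the docstring: a `[X, Y, Z]` tensor split into local chunks of shape `[x, y, z]` is treated as a `[X // x, Y // y, Z // z, x * y * z]` tensor whose last axis is sliced by `flattened_range`. The arithmetic sketch below is illustrative only; `reformulated_global_shape` is a made-up stand-in for the real helper `nd_flattened_tensor_reformulated_global_shape`.

```python
# Illustrative sketch of the shape reformulation for N-D flattened ShardedTensors.
# `reformulated_global_shape` is a hypothetical stand-in for
# `nd_flattened_tensor_reformulated_global_shape`.
from math import prod
from typing import Tuple


def reformulated_global_shape(
    global_shape: Tuple[int, ...], local_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """[X, Y, Z] global with [x, y, z] local chunks -> [X // x, Y // y, Z // z, x * y * z]."""
    assert all(g % l == 0 for g, l in zip(global_shape, local_shape)), "regular grid assumed"
    chunks_per_axis = tuple(g // l for g, l in zip(global_shape, local_shape))
    return chunks_per_axis + (prod(local_shape),)


if __name__ == "__main__":
    # An (8, 4) tensor sharded into (4, 2) chunks becomes a (2, 2, 8) grid of flattened
    # chunks; each rank's `flattened_range` then selects a slice of the last axis.
    assert reformulated_global_shape((8, 4), (4, 2)) == (2, 2, 8)
```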
+ placement = f"rank:{(rank + 1) % world_size}/cuda" + if has_flattened_range: + offset = offset + (0,) + size = (1,) * len(offsets_shape) + global_shape[-1:] + else: + size = offsets_shape + shard_metadata.append(ShardMetadata(offset, size, placement)) + + tensor = some_sh_ten.data + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=torch.Size(global_shape), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None + ) + # Store MCore related data as PyTShardedTensor attribute. + # This won't be stored in the checkpoint, only for runtime purposes + pyt_sh_ten.mcore_sh_ten = sh_ten.without_data() + pyt_sh_ten.mcore_metadata = {} + if has_flattened_range: + pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape + return pyt_sh_ten + + +def mcore_to_pyt_state_dict( + state_dict: Dict[str, List[ShardedBase]], + is_loading: bool = False, + init_device: torch.device = torch.device("cpu"), + load_legacy_1d_flatten_tensors: bool = False, +) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]: + """Convert state dict with ShardedTensors and ShardedObjects + to state dict compatible with PyT Dist format. + + Operates in-place and returns the original state dict. + + Args: + state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values + are lists of either ShardedTensor or ShardedObjects. + is_loading (bool, optional): flag indicating if loading or saving. Defaults to False. + init_device (torch.device, optional): device to initialize potentially missing tensors + during loading. Defaults to 'cpu'. + + Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values + converted either into PyT ShardedTensors or io.BytesIO. + + """ + rank = torch.distributed.get_rank() + pyt_state_dict = {} + + def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor: + """Build a PyT ShardedTensor from given shards. 
+ + During loading: + - if data is None, initialize it with an empty tensor (will be used to copy the data into) + - if `allow_shape_mismatch` is True, the data is initialized with zeros + prior to loading (not all parts of the tensor will be read from the checkpoint) + """ + assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens + for sh_ten in sh_tens: + if sh_ten.data is None: + if is_loading: + sh_ten.init_data( + init_device, + init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty, + ) + else: + raise CheckpointingException(f'`data` attr is None for {sh_ten}') + else: + sh_ten.data = sh_ten.data.detach() + if sh_ten.allow_shape_mismatch and is_loading: + sh_ten.data.zero_() + + torch_sh_ten = sharded_tensor_to_torch_sharded_tensor( + sh_tens, rank, load_legacy_1d_flatten_tensors + ) + torch_sh_ten.key = sh_tens[0].key + return torch_sh_ten + + def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO: + """Build io.BytesIO from given sharded objects data.""" + assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs + serialized_data = io.BytesIO() + torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data) + return serialized_data + + for k, v in state_dict.items(): + if isinstance(v[0], ShardedTensor): + v = cast(List[ShardedTensor], v) + pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v) + else: + v = cast(List[ShardedObject], v) + pyt_state_dict[k] = _mcore_to_torch_sharded_object(v) + + return pyt_state_dict + + +def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]: + """Unwrap tensor from PyT ShardedTensor instance. + + If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor) + then the tensor has additional singleton dimensions which should be squeezed. 
+ """ + mcore_sh_ten = sh_ten.mcore_sh_ten + ret_tensors = [] + for sh in sh_ten.local_shards(): + ten = sh.tensor + if mcore_sh_ten.flattened_range is not None: + assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape + ten = ten.view(-1) + else: + for _ in range(mcore_sh_ten.prepend_axis_num): + ten = ten.squeeze(0) + ret_tensors.append(ten) + return ret_tensors + + +def _replace_state_dict_keys_with_sharded_keys( + sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False +) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]: + """Group ShardedBase objects by keys and + return mappings required for recreating the original dict.""" + flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict) + rename_mapping = defaultdict(list) + new_flat_sd = defaultdict(list) + for k, sh_base in flat_sd.items(): + assert isinstance(sh_base, ShardedBase), type(sh_base) + key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key + if is_main_replica(sh_base.replica_id) or not keep_only_main_replica: + rename_mapping[key].append(k) + new_flat_sd[key].append(sh_base) + return new_flat_sd, flat_mapping, rename_mapping + + +def _replace_sharded_keys_with_state_dict_keys( + state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]], + flat_mapping: FLATTEN_MAPPING, + rename_mapping: Dict[str, List[str]], +): + """Inverse of _replace_state_dict_keys_with_sharded_keys.""" + recovered_sd = {} + for k, tensors in state_dict.items(): + assert len(tensors) == len(rename_mapping[k]) + for ten, recovered_k in zip(tensors, rename_mapping[k]): + recovered_sd[recovered_k] = ten + + return unflatten_state_dict(recovered_sd, flat_mapping) + + +def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]): + """Recursively update `x` keys, based on `keys_template`.""" + if isinstance(keys_template, dict): + assert isinstance(x, dict), type(x) + for k, v in keys_template.items(): + if not isinstance(k, str): + assert str(k) in x, (k, x.keys) + x[k] = x.pop(str(k)) + _restore_dict_types(x[k], v) + elif isinstance(keys_template, list): + assert isinstance(x, list), type(x) + for x_val, templ_val in zip(x, keys_template): + _restore_dict_types(x_val, templ_val) + + +@dataclass(frozen=True) +class MCoreSavePlan(SavePlan): + """SavePlan with MCore specific data.""" + + mcore_data: Dict[str, Dict[str, Any]] = None # Mcore related data about each tensor + + +class MCoreSavePlanner(DefaultSavePlanner): + """Differs with the default planner by saving BytesIO objects on all ranks. + + In the integration of MCore with PyT Distributed format, BytesIO objects + come from ShardedObjects, which should be treated as separate objects on each rank + (not common on all ranks). + + Also, the objects are already packed in io.BytesIO, so no need to redo it + in transform_object. + """ + + def __init__( + self, + *args, + dedup_replicated_tensors: Optional[bool] = None, + nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None, + can_run_decentralized_global_plan: bool = True, + **kwargs, + ) -> None: + # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings + # during saving. 
+ if get_torch_version() <= PkgVersion("2.2"): + kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors + super().__init__(*args, **kwargs) + self.nd_flattened_global_shapes = nd_flattened_global_shapes or {} + self.can_run_decentralized_global_plan = can_run_decentralized_global_plan + if can_run_decentralized_global_plan: + assert ( + not dedup_replicated_tensors + ), 'Cannot run decentralized plan with dedup_replicated_tensors=True' + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' + + def create_local_plan(self) -> SavePlan: + """Adds IOBytes write request on non-coordinator ranks.""" + + # NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because + # some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container) + # add iobytes request only on coordinator ranks and some alpha versions + # (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container) + # add those requests on all ranks. We inline a simplified version of this method below. + write_items = [] + for fqn, obj in self.state_dict.items(): + assert not HAVE_DTENSOR or not isinstance( + obj, DTensor + ) # translation from MCore ShardedTensors shouldn't result in DTensors + # Create write requests for tensor and bytes values. + # For MCore, these should be already non-duplicates. + write_items += _create_write_items(fqn, obj) + + self.plan = MCoreSavePlan( + items=write_items, + planner_data=self.mappings, + mcore_data={ + k: sh_ten.mcore_metadata + for k, sh_ten in self.state_dict.items() + if isinstance(sh_ten, TorchShardedTensor) + }, + ) + return self.plan + + def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: + """Merges MCore data for all plans.""" + global_plan, metadata = super().create_global_plan(all_plans) + metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans))) + return global_plan, metadata + + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Nothing to do, just some checks. + + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' + assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan' + return local_plan + + def transform_object(self, write_item: WriteItem, object: Any): + """Make no transformations - bytes objects are already serialized.""" + return object + + +class MCoreLoadPlanner(DefaultLoadPlanner): + """Adds global shape validation to the default planner. + + If global shape validation can be ignored (shouldn't!), the default + load planner can be used. 
+ """ + + def __init__( + self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors + self._intermediate_read_item_and_target: Optional[Tuple[ReadItem, torch.Tensor]] = None + + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: + raise KeyError( + f"{sh_ten.key} from model not in state dict:" + f" {sorted(metadata.state_dict_metadata.keys())}" + ) + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + if not is_nd_flattened_tensor(sh_ten): + expected_shape = sh_ten.global_shape + else: + expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) + if loaded_shape != expected_shape: + if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1: + # Handle legacy 1-D flattened tensors checkpoint format + # where the global shape is not stored in the metadata + expected_shape = sh_ten.global_shape + if loaded_shape == expected_shape: + continue + _msg = ( + f'Global shape mismatch for loaded ({loaded_shape})' + f' and expected ({expected_shape}) tensor' + f' for key {sh_ten.key}' + ) + raise CheckpointingException(_msg) + + def create_local_plan(self) -> LoadPlan: + """Runs additional shapes validation.""" + self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors) + return super().create_local_plan() + + def resolve_tensor(self, read_item: ReadItem): + """Override to add FP8 support. + + Narrowing the Float8Tensor can create incontiguous tensors and there are + no `copy` kernels for such cases. This method creates a contiguous FP8 + tensors so that the subsequent `copy_` in FileSystemReader succeeds. + Note that this requires tracking the original tensor + (as `self._intermediate_read_item_and_target` attribute) + and restoring it in `commit_tensor` method. + """ + target_tensor = super().resolve_tensor(read_item) + if ( + not target_tensor.is_contiguous() + and HAVE_TE + and isinstance(target_tensor, Float8Tensor) + ): + self._intermediate_read_item_and_target = (read_item, target_tensor) + target_tensor = Float8Tensor.make_like( + target_tensor, data=target_tensor._data.contiguous() + ) + return target_tensor + + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + """Restores the original FP8 tensor saved in `resolve_tensor`.""" + if self._intermediate_read_item_and_target is not None: + interm_read_item, target_tensor = self._intermediate_read_item_and_target + assert ( + interm_read_item is read_item + ), '`commit_tensor` method should be called right after `resolve_tensor`' + target_tensor.copy_(tensor) + tensor = target_tensor + self._intermediate_read_item_and_target = None + return super().commit_tensor(read_item, tensor) + + +class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): + """Async save strategy for the PyT Distributed format. + + The idea is to translate MCore ShardedTensors into PyT ShardedTensors + and use the async-adjusted torch.distributed.checkpoint saving mechanism + provided by the FileSystemWriterAsync writer. 
+ """ + + def __init__( + self, + backend: str, + version: int, + keep_only_main_replica: bool = True, + thread_count: int = 2, + cached_metadata: bool = False, + separation_hint: str = None, + ): + """Adds parameters specific to PyT Distributed format + Args: + backend (str): format backend string + version (int): format version + keep_only_main_replica (bool, optional): PyT Distributed has a mechanism + for deduplication, but replica_id aware deduplication is more coherent. + Default is True (recommended to keep it). + thread_count (int, optional): threads to use during saving. + Affects the number of files in the checkpoint (saving ranks * num_threads). + cached_metadata (bool, optional): Enables using cached global metadata to avoid + gathering local metadata every checkpointing invocation + separation_hint(str, optional): If provided, all tensors whose keys have this + prefix will be saved to a separate file. + """ + super().__init__(backend, version) + self.keep_only_main_replica = keep_only_main_replica + self.thread_count = thread_count + + # Cached SavePlans to skip plan in `save_state_dict_async_plan` + # cached outcome of `SavePlan.prepare_global_plan`, + # which aggregates local plans from all ranks + self.cached_central_plan: SavePlan = None + # cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written + self.cached_local_plan: SavePlan = None + # Cached global metadata, only `coordinator` for dist-ckpt holds + # if central plans are consistent over iters + self.cached_global_metadata: Metadata = None + # This variable records if the ckpt structures are consistent + # so the following checkpoint savings reuse `cached_global_metadata` + self.validated_cache_reuse: bool = False + # The knob to enable cached metadata communication in saving + self.use_cached_ckpt_structure: bool = cached_metadata + + self.separation_hint = separation_hint + + self.validated_loaded_metadata_reuse = False + + def async_save( + self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path + ) -> AsyncRequest: + """Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to save + checkpoint_dir (Path): checkpoint directory + + Returns: None + """ + # Translate the state dict + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys( + sharded_state_dict, self.keep_only_main_replica + ) + ) + pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) + # Use PyT saving mechanism + writer = FileSystemWriterAsync( + checkpoint_dir, separation_hint=self.separation_hint, thread_count=self.thread_count + ) + # This should be set differently if we run in a smaller process group than the default + coordinator = 0 + # Try twice to validate the generated `central_plan` is the same across iterations + # If so, reuse `cached_central_plan` and `cached_global_metadata` + # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` + # (return None) so `self.cached_global_metadata` is reused + args_cached_plans = None + loaded_all_plans = None + if self.use_cached_ckpt_structure: + loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None) + if loaded_all_plans is None: + logger.debug( + "no all_local_plans in metadata - can't verify global metadata reuse..." 
+ ) + + args_cached_plans = ( + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + ) + + ( + save_state_dict_ret, + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + self.validated_loaded_metadata_reuse, + ) = save_state_dict_async_plan( + pyt_state_dict, + writer, + None, + coordinator, + planner=MCoreSavePlanner( + dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False + ), + cached_ckpt_structure=args_cached_plans, + loaded_all_plans=loaded_all_plans, + ) + rank = torch.distributed.get_rank() + if self.use_cached_ckpt_structure: + if ( + loaded_all_plans + and self.cached_global_metadata + and self.validated_loaded_metadata_reuse + ): + if coordinator == rank: + logger.debug( + f"rank: {rank}, reuse global metadata from loaded" + f" .metadata, {save_state_dict_ret[1]}" + ) + save_state_dict_ret = list(save_state_dict_ret) + save_state_dict_ret[1] = self.cached_global_metadata + + elif self.validated_cache_reuse: + logger.debug(f"rank: {rank}, cache validated") + if save_state_dict_ret[1]: # when global_metadata is not cached + self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata + # Only Coordinator rank holds cached global_metadata + # (None is returned for global_metadata) + elif coordinator == rank: + logger.debug( + f"rank: {rank}, reuse global metadata cached from previous" + f" save iteration, {save_state_dict_ret[1]}" + ) + save_state_dict_ret = list(save_state_dict_ret) + save_state_dict_ret[1] = self.cached_global_metadata + + return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret) + + def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest: + save_fn_args = writer.get_save_function_and_args() + save_fn, preload_fn, save_args = save_fn_args + + def finalize_fn(): + save_state_dict_async_finalize(*save_state_dict_ret) + torch.distributed.barrier() + + return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn) + + def can_handle_sharded_objects(self): + return True + + +def get_reformulation_metadata( + sharded_state_dict: ShardedStateDict, checkpoint_dir: Path +) -> Dict[str, TensorReformulationMetadata]: + """Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to load + checkpoint_dir (Path): checkpoint directory + + Returns: + Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every + N-D flattened tensor from the sharded_state_dict to its original global shape + as stored in `mcore_data` in the checkpoint. + """ + ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata() + reformulation_metadata = {} + for sh_ten in nested_values(sharded_state_dict): + if not is_nd_flattened_tensor(sh_ten): + continue + try: + ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][ + 'nd_reformulated_orig_global_shape' + ] + except KeyError as e: + if len(sh_ten.global_shape) == 1: + warnings.warn( + f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. ' + 'Skip metadata reformulation.' 
+ ) + continue + raise CheckpointingException( + f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} ' + f'in checkpoint metadata: {ckpt_metadata.mcore_data}' + ) from e + + reformulation_metadata[sh_ten.key] = TensorReformulationMetadata( + ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size + ) + return reformulation_metadata + + +class TorchDistLoadShardedStrategy(LoadShardedStrategy): + """Basic load strategy for the PyT Distributed format.""" + + def __init__(self): + self.cached_global_metadata: Optional[Metadata] = None + super().__init__() + + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: + """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict with mapping + information to instruct loading + checkpoint_dir (Path): checkpoint directory + + Returns: loaded state dict + """ + # Apply N-D tensors resharding + reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir) + sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation( + sharded_state_dict, reformulation_metadata + ) + + # Check if there are legacy 1-D flattened tensors in the checkpoint + has_legacy_1d_flattened_tensors = False + for sh_ten in nested_values(sharded_state_dict): + if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata: + has_legacy_1d_flattened_tensors = True + break + + flexible_shape_sharded_tensors = [ + sh_ten + for sh_ten in nested_values(sharded_state_dict) + if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch + ] + + orig_sharded_state_dict = sharded_state_dict + # MCore state dict to PyT Distributed compatible + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys(sharded_state_dict) + ) + pyt_state_dict = mcore_to_pyt_state_dict( + sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors + ) + # Load PyT Distributed format + fsr = CachedMetadataFileSystemReader(checkpoint_dir) + checkpoint.load_state_dict( + pyt_state_dict, + fsr, + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors + ), + ) + + self.cached_global_metadata = ( + fsr.read_metadata() + ) # no storage interaction thanks to caching + + pyt_state_dict = cast( + Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict + ) + # Unwrap ShardedTensors and return to original state dict + mcore_state_dict = { + k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v) + for k, v in pyt_state_dict.items() + } + mcore_state_dict = _replace_sharded_keys_with_state_dict_keys( + mcore_state_dict, flat_mapping, rename_mapping + ) + _restore_dict_types(mcore_state_dict, orig_sharded_state_dict) + # Apply N-D tensors resharding postprocessing + mcore_state_dict = restore_nd_flattened_tensors_formulation( + mcore_state_dict, formulation_restore_data + ) + return mcore_state_dict + + def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None): + """Uses tensors metadata stored in the metadata file.""" + if metadata is None: + fs_reader = FileSystemReader(checkpoint_dir) + metadata = fs_reader.read_metadata() + + mcore_data = getattr(metadata, 'mcore_data', {}) + sharded_metadata = {} + for k, tp in metadata.state_dict_metadata.items(): + if not isinstance(tp, TensorStorageMetadata): + continue # 
load only tensors + + nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape') + if nd_orig_global_shape is None: + # Regular tensor + sharded_metadata[k] = ShardedTensor.from_rank_offsets( + k, torch.empty(tp.size, **tp.properties.__dict__, device='meta') + ).without_data() + else: + # N-D flattened tensor + unflat_ten = torch.empty( + nd_orig_global_shape, **tp.properties.__dict__, device='meta' + ) + flat_ten = unflat_ten.flatten() + sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat( + k, + flat_ten, + unflat_ten.shape, + flattened_range=slice(0, unflat_ten.numel()), # whole slice + ).without_data() + + return sharded_metadata + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + """Uses tensors and objects metadata stored in the metadata file.""" + fs_reader = FileSystemReader(checkpoint_dir) + metadata = fs_reader.read_metadata() + + sharded_metadata = {} + for metadata_key, storage_metadata in metadata.state_dict_metadata.items(): + if not isinstance(storage_metadata, BytesStorageMetadata): + continue + sh_obj = ShardedObject.empty_from_unique_key(metadata_key) + sharded_metadata[sh_obj.unique_key] = sh_obj + + sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata)) + return sharded_metadata + + def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): + """Removes checkpoint files whose keys have the given prefix. + + Performs the following steps: + 1. checks whether there are files that start with the key_prefix + 2. loads metadata + 3. removes all entries from the metadata that start with the key_prefix + 4. resaves the new metadata and removes the old metadata + 5. removes the relevant files + """ + + assert is_torch_min_version( + "2.3.0" + ), f'torch >= 2.3.0 is required for remove_sharded_tensors' + + distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")] + files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)] + + if not files_to_remove: + warnings.warn( + f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".' + f' Skipping removal.' 
+ ) + return + + fs_reader = FileSystemReader(checkpoint_dir) + original_metadata = fs_reader.read_metadata() + + new_state_dict_metadata = {} + new_planner_data = {} + new_storage_data = {} + for k in original_metadata.state_dict_metadata.keys(): + if k.startswith(key_prefix): + continue + new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k] + for k in original_metadata.planner_data.keys(): + if k.startswith(key_prefix): + continue + new_planner_data[k] = original_metadata.planner_data[k] + for k in original_metadata.storage_data.keys(): + if k.fqn.startswith(key_prefix): + continue + new_storage_data[k] = original_metadata.storage_data[k] + metadata = Metadata( + state_dict_metadata=new_state_dict_metadata, + planner_data=new_planner_data, + storage_data=new_storage_data, + ) + fs_writer = FileSystemWriter(checkpoint_dir) + metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn)) + tmp_path = cast( + metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp") + ) + old_path = cast( + metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck") + ) + ## save the new metadata + with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + try: + os.fsync(metadata_file.fileno()) + except AttributeError: + os.sync() + ## move the old metadata + fs_writer.fs.rename(fs_writer.metadata_path, old_path) + try: + ## rename the new metadata + fs_writer.fs.rename(tmp_path, fs_writer.metadata_path) + + ## finally, remove the files we want to drop + for f in files_to_remove: + fs_writer.fs.rm_file(checkpoint_dir / f) + except Exception as e: + fs_writer.fs.rename(old_path, fs_writer.metadata_path) + raise e + else: + fs_writer.fs.rm_file(old_path) + + def can_handle_sharded_objects(self): + return True + + def check_backend_compatibility(self, loaded_version): + pass # TODO + + def check_version_compatibility(self, loaded_version): + pass # TODO diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py index 72e60bc79b9792976f9dd4b84ff9e714679a36df..b8a5094d43d723b47262d0101a6b52f7b825b275 100644 --- a/megatron/core/dist_checkpointing/strategies/two_stage.py +++ b/megatron/core/dist_checkpointing/strategies/two_stage.py @@ -1,254 +1,268 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" 2-stage checkpoint loading. 
""" -import os -import time -from collections import defaultdict -from dataclasses import dataclass -from functools import partial, wraps -from itertools import chain -from logging import DEBUG, INFO, StreamHandler, getLogger -from operator import attrgetter, itemgetter -from pathlib import Path -from typing import Iterable, List, NamedTuple, Optional, Tuple, Union - -import torch - -from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values -from ..mapping import ShardedStateDict, ShardedTensor, StateDict -from .base import LoadShardedStrategy -from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array -from .zarr import flatten_range, load_zarr_based_sharded_metadata - -_import_trigger = None - - -timers = defaultdict(list) - -logger = getLogger(__name__) - - -def timed(verbose=True): - def timed_dec(fn): - name = fn.__name__ - - @wraps(fn) - def wrapped(*args, **kwargs): - if verbose: - logger.debug(f'{name} init') - start = time.time() - ret = fn(*args, **kwargs) - took = time.time() - start - if verbose: - logger.debug(f'{name} took {took}s') - timers[name].append(took) - return ret - - return wrapped - - return timed_dec - - -@dataclass -class _ShardedTensorMetadata: - global_rank: int - sharded_tensor_no_data: ShardedTensor - dist_group_rank: Tuple[int] # id of distributed group - dist_group_ranks: Tuple[int] # id of distributed group - data_size: Optional[int] = None # bytes - - -def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): - return (sharded_tensor.key, sharded_tensor.global_offset) - - -class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): - """Loads one checkpoint replica from storage and broadcasts to other nodes. - - This strategy loads checkpoint from storage on minimal set of nodes - and distributes the checkpoint to other nodes with torch.distributed. - Loading is performed with tensorstore. - - Steps: - 0. (optional) create Gloo distributed groups - 1. Exchange ShardedTensors metadata between all nodes - 2. Align needed tensors within DP groups - 3. For each globally unique tensor: - 3.a) on one of the ranks load it from storage to CPU and move to CUDA - 3.b) allocate CUDA tensor on other ranks - 3.c) broadcast within DP group - 3.d) copy tensor content to the model param location - 3.e) free tensor buffers from a) and b) - - Notes: - 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs - 2. 
There is a lot of overlap potential between all three steps done for each tensor: - 2.a) loading from storage to numpy - 2.b) moving CPU tensors to CUDA - 2.c) broadcast - """ - - def __init__(self, data_parallel_group, cpu_transfer=True): - super().__init__() - - self.cpu_transfer = cpu_transfer - self.data_parallel_group_orig = data_parallel_group - self.data_parallel_group = None if cpu_transfer else data_parallel_group - self.dp_group_ranks = tuple( - sorted(torch.distributed.get_process_group_ranks(data_parallel_group)) - ) - self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig) - self.global_rank = torch.distributed.get_rank() - - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - self.maybe_init_gloo_group() - all_tensors_sorted = self._build_load_plan(sharded_state_dict) - self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) - # TODO: fix hang in summarize_load_times - # self.summarize_load_times() - return sharded_state_dict - - def summarize_load_times(self): - torch.distributed.barrier() - logger.info('Checkpoint loading finished. Summary:') - # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs - for key, times in sorted(timers.items()): - times_sum = sum(times) - max_times = torch.tensor([times_sum], device='cuda') - avg_times = torch.tensor([times_sum], device='cuda') - torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX) - torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM) - avg_times /= torch.distributed.get_world_size() - if torch.distributed.get_rank() == 0: - logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}') - - @timed(verbose=False) - def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): - logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') - ret = _load_from_array( - ten_meta.sharded_tensor_no_data, - checkpoint_dir, - load_directly_on_device=False, - apply_flattened_range=False, - ) - logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE') - return ret - - @timed() - def maybe_init_gloo_group(self): - if not self.cpu_transfer: - return - all_groups = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) - all_groups = set(tuple(sorted(gr)) for gr in all_groups) - for group_ranks in sorted(all_groups): - gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') - if self.global_rank in group_ranks: - self.data_parallel_group = gloo_pg - assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group) - - def check_backend_compatibility(self, loaded_version): - pass # TODO - - def check_version_compatibility(self, loaded_version): - pass # TODO - - @timed() - def _build_load_plan( - self, sharded_state_dict: ShardedStateDict - ) -> List[_ShardedTensorMetadata]: - local_meta = [ - _ShardedTensorMetadata( - self.global_rank, - sharded_ten.without_data(), - self.dp_group_rank, - self.dp_group_ranks, - ) - for sharded_ten in nested_values(sharded_state_dict) - ] - all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group) - torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group) - all_meta = list(chain.from_iterable(all_meta)) - all_tensors_sorted = self.deduplicate_chunks(all_meta) - return all_tensors_sorted - - @timed() - def deduplicate_chunks(self, ten_metas: 
List[_ShardedTensorMetadata]): - """Group tensors by chunk and then pick the tensor with the lowest rank. - - NOTE: with proper loading overlap, loading from randomized ranks - (instead of the smallest one) could be beneficial here. - """ - ten_metas = map_reduce( - ten_metas, - key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data), - reduce_fn=partial(min, key=attrgetter('dist_group_rank')), - ) - all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items()))) - return all_metas_sorted - - @timed() - def _exchange_loaded_tensors( - self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir - ): - logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}') - for ten_meta in ten_metas: - - src_rank = torch.distributed.get_global_rank( - self.data_parallel_group, ten_meta.dist_group_rank - ) - - if self.dp_group_rank == ten_meta.dist_group_rank: - exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta) - if not self.cpu_transfer: - exchange_tensor = exchange_tensor.cuda() - else: - # TODO: for non-flattened ranges we could reuse the buffer from the start here - exchange_tensor = torch.empty( - ten_meta.sharded_tensor_no_data.local_shape, - device='cpu' if self.cpu_transfer else 'cuda', - dtype=ten_meta.sharded_tensor_no_data.dtype, - ) - - logger.debug( - f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' - ) - torch.distributed.broadcast( - exchange_tensor, group=self.data_parallel_group, src=src_rank - ) - self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict) - logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done') - - # free buffer memory - exchange_tensor = None - - @timed(verbose=False) - def _distribute_data_to_state_dict( - self, - ten_meta: _ShardedTensorMetadata, - loaded_ten: torch.Tensor, - sharded_state_dict: ShardedStateDict, - ): - tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data) - - def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]): - if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key: - # already filled-in or key not matching - return t - sharded_tensor: ShardedTensor = t - x = loaded_ten - if sharded_tensor.flattened_range is not None: - x = flatten_range(sharded_tensor, x) - - # Reuse existing buffer - sharded_tensor.data.data.copy_(x) - return sharded_tensor.data - - dict_list_map_inplace(_fill_in_data, sharded_state_dict) - - def load_tensors_metadata(self, checkpoint_dir: Path): - def get_ts_shape_dtype(path): - arr = open_ts_array(path) - return arr.shape, arr.dtype.numpy_dtype - - return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype) +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" 2-stage checkpoint loading. 
""" +import time +from collections import defaultdict +from dataclasses import dataclass +from functools import partial, wraps +from itertools import chain +from logging import getLogger +from operator import attrgetter, itemgetter +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch + +from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values +from ..mapping import ShardedStateDict, ShardedTensor +from .base import LoadShardedStrategy +from .tensorstore import _load_from_array, open_ts_array +from .zarr import flatten_range, load_zarr_based_sharded_metadata + +_import_trigger = None + + +timers = defaultdict(list) + +logger = getLogger(__name__) +logger.warning( + 'megatron.core.dist_checkpointing.two_stage module is deprecated' + ' and will be removed in Megatron-Core v0.12. Please use' + ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.' +) + + +def timed(verbose=True): + """Timing decorator.""" + + def timed_dec(fn): + name = fn.__name__ + + @wraps(fn) + def wrapped(*args, **kwargs): + if verbose: + logger.debug(f'{name} init') + start = time.time() + ret = fn(*args, **kwargs) + took = time.time() - start + if verbose: + logger.debug(f'{name} took {took}s') + timers[name].append(took) + return ret + + return wrapped + + return timed_dec + + +@dataclass +class _ShardedTensorMetadata: + global_rank: int + sharded_tensor_no_data: ShardedTensor + dist_group_rank: Tuple[int] # id of distributed group + dist_group_ranks: Tuple[int] # id of distributed group + data_size: Optional[int] = None # bytes + + +def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): + """Id of a sharded tensor.""" + return (sharded_tensor.key, sharded_tensor.global_offset) + + +class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): + """Loads one checkpoint replica from storage and broadcasts to other nodes. + + This strategy loads checkpoint from storage on minimal set of nodes + and distributes the checkpoint to other nodes with torch.distributed. + Loading is performed with tensorstore. + + Steps: + 0. (optional) create Gloo distributed groups + 1. Exchange ShardedTensors metadata between all nodes + 2. Align needed tensors within DP groups + 3. For each globally unique tensor: + 3.a) on one of the ranks load it from storage to CPU and move to CUDA + 3.b) allocate CUDA tensor on other ranks + 3.c) broadcast within DP group + 3.d) copy tensor content to the model param location + 3.e) free tensor buffers from a) and b) + + Notes: + 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs + 2. 
There is a lot of overlap potential between all three steps done for each tensor: + 2.a) loading from storage to numpy + 2.b) moving CPU tensors to CUDA + 2.c) broadcast + """ + + def __init__(self, data_parallel_group, cpu_transfer=True): + super().__init__() + + self.cpu_transfer = cpu_transfer + self.data_parallel_group_orig = data_parallel_group + self.data_parallel_group = None if cpu_transfer else data_parallel_group + self.dp_group_ranks = tuple( + sorted(torch.distributed.get_process_group_ranks(data_parallel_group)) + ) + self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig) + self.global_rank = torch.distributed.get_rank() + + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Main load method.""" + self.maybe_init_gloo_group() + all_tensors_sorted = self._build_load_plan(sharded_state_dict) + self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) + # TODO: fix hang in summarize_load_times + # self.summarize_load_times() + return sharded_state_dict + + def summarize_load_times(self): + """Summarize load times.""" + torch.distributed.barrier() + logger.info('Checkpoint loading finished. Summary:') + # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs + for key, times in sorted(timers.items()): + times_sum = sum(times) + max_times = torch.tensor([times_sum], device='cuda') + avg_times = torch.tensor([times_sum], device='cuda') + torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX) + torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM) + avg_times /= torch.distributed.get_world_size() + if torch.distributed.get_rank() == 0: + logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}') + + @timed(verbose=False) + def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): + """Load tensor from storage.""" + logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') + ret = _load_from_array( + ten_meta.sharded_tensor_no_data, + checkpoint_dir, + load_directly_on_device=False, + apply_flattened_range=False, + ) + logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE') + return ret + + @timed() + def maybe_init_gloo_group(self): + """Create Gloo groups.""" + if not self.cpu_transfer: + return + all_groups = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) + all_groups = set(tuple(sorted(gr)) for gr in all_groups) + for group_ranks in sorted(all_groups): + # "two_stage" module will be deprecated, so not replace new_group() + # with ...parallel_state.create_group() func setting group_desc here. 
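+            # NOTE: torch.distributed.new_group is a collective call: every rank must
+            # create every group (and in the same order), which is why the gathered
+            # `all_groups` are iterated here instead of creating only this rank's DP group.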
+ gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') + if self.global_rank in group_ranks: + self.data_parallel_group = gloo_pg + assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group) + + def check_backend_compatibility(self, loaded_version): + pass # TODO + + def check_version_compatibility(self, loaded_version): + pass # TODO + + @timed() + def _build_load_plan( + self, sharded_state_dict: ShardedStateDict + ) -> List[_ShardedTensorMetadata]: + local_meta = [ + _ShardedTensorMetadata( + self.global_rank, + sharded_ten.without_data(), + self.dp_group_rank, + self.dp_group_ranks, + ) + for sharded_ten in nested_values(sharded_state_dict) + ] + all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group) + torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group) + all_meta = list(chain.from_iterable(all_meta)) + all_tensors_sorted = self.deduplicate_chunks(all_meta) + return all_tensors_sorted + + @timed() + def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]): + """Group tensors by chunk and then pick the tensor with the lowest rank. + + NOTE: with proper loading overlap, loading from randomized ranks + (instead of the smallest one) could be beneficial here. + """ + ten_metas = map_reduce( + ten_metas, + key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data), + reduce_fn=partial(min, key=attrgetter('dist_group_rank')), + ) + all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items()))) + return all_metas_sorted + + @timed() + def _exchange_loaded_tensors( + self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir + ): + logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}') + for ten_meta in ten_metas: + + src_rank = torch.distributed.get_global_rank( + self.data_parallel_group, ten_meta.dist_group_rank + ) + + if self.dp_group_rank == ten_meta.dist_group_rank: + exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta) + if not self.cpu_transfer: + exchange_tensor = exchange_tensor.cuda() + else: + # TODO: for non-flattened ranges we could reuse the buffer from the start here + exchange_tensor = torch.empty( + ten_meta.sharded_tensor_no_data.local_shape, + device='cpu' if self.cpu_transfer else 'cuda', + dtype=ten_meta.sharded_tensor_no_data.dtype, + ) + + logger.debug( + f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}\ +({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' + ) + torch.distributed.broadcast( + exchange_tensor, group=self.data_parallel_group, src=src_rank + ) + self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict) + logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done') + + # free buffer memory + exchange_tensor = None + + @timed(verbose=False) + def _distribute_data_to_state_dict( + self, + ten_meta: _ShardedTensorMetadata, + loaded_ten: torch.Tensor, + sharded_state_dict: ShardedStateDict, + ): + tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data) + + def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]): + if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key: + # already filled-in or key not matching + return t + sharded_tensor: ShardedTensor = t + x = loaded_ten + if sharded_tensor.flattened_range is not None: + x = flatten_range(sharded_tensor, x) + + # Reuse existing buffer + sharded_tensor.data.data.copy_(x) + return 
sharded_tensor.data + + dict_list_map_inplace(_fill_in_data, sharded_state_dict) + + def load_tensors_metadata(self, checkpoint_dir: Path): + def get_ts_shape_dtype(path): + arr = open_ts_array(path) + return arr.shape, arr.dtype.numpy_dtype + + return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype) diff --git a/megatron/core/dist_checkpointing/tensor_aware_state_dict.py b/megatron/core/dist_checkpointing/tensor_aware_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3d11b9d4e3ef1935d00adcfaeee735ebed4996 --- /dev/null +++ b/megatron/core/dist_checkpointing/tensor_aware_state_dict.py @@ -0,0 +1,347 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Utilities for transforming state_dict, including a tensor-aware implementation.""" + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple + +import torch +from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict + +from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values +from .exchange_utils import ( + ShardDistribution, + determine_main_replica_uniform_distribution, + exchange_by_distribution, +) +from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges +from .state_dict_utils import load_preprocess, save_preprocess +from .utils import ( + _sharded_object_id, + _sharded_tensor_shard_id, + debug_time, + extract_sharded_base, + zip_strict, +) +from .validation import determine_global_metadata, validate_sharding_integrity + +logger = logging.getLogger(__name__) + + +@dataclass +class MCoreTensorAwareStateDict(TensorAwareStateDict): + """ + MCore-specific class defining the interface between the MCore state dict and checkpoint manager. + + This class distinguishes between raw objects, the common state dict, and sharded state dicts + (tensor parts). It also handles optional metadata needed for fully parallel save/load. + """ + + common: StateDict + sharded_state_dict: ShardedStateDict + _is_hollow: bool = False + + @staticmethod + def _validate_params(algo): + if algo != 'atomic' and algo != 'fully_parallel': + raise NotImplementedError( + 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' 
+ ) + + @staticmethod + def _get_distribution( + fully_parallel, sharded_part, parallelization_group, cached_distribution=None + ): + if fully_parallel: + if cached_distribution is None: + distribution = determine_main_replica_uniform_distribution( + sharded_part, parallelization_group, True + ) + logger.debug(f'MCore_TASD._get_distribution calculated distribution') + else: + distribution = cached_distribution + logger.debug(f'MCore_TASD._get_distribution used cache') + else: + distribution = (None, None, None, None) + logger.debug(f'MCore_TASD._get_distribution returned empty distribution') + return distribution + + @staticmethod + def _remove_redundant_data( + fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group + ): + if fully_parallel: + for sh_base in nested_values(sharded_part): + # TODO remove redundant objects as well + if isinstance(sh_base, ShardedTensor): + shard_id = _sharded_tensor_shard_id(sh_base) + if shard_to_saving_rank[shard_id] != torch.distributed.get_rank( + group=parallelization_group + ): + sh_base.data = None + + @classmethod + @debug_time("from_state_dict", logger) + def from_state_dict( + cls, + sharded_state_dict: ShardedStateDict, + algo: str = 'fully_parallel', + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + cached_metadata: ShardDistribution = None, + ) -> Tuple[TensorAwareStateDict, ShardDistribution]: + """ + Constructs a TensorAwareStateDict from a sharded state dictionary. + + This method preprocesses the input `sharded_state_dict`, validates parameters, + and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`. + + Args: + sharded_state_dict: The input sharded state dictionary to be converted. + algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'. + - 'fully_parallel' enables fully parallel initialization. + parallelization_group (Optional): A distributed process group for parallelization. + cached_metadata (Optional): Precomputed metadata from previous saves. + - Reuses data that doesn't need recalculation, optimizing the creation process. + + Returns: + TensorAwareStateDict: An instance initialized with the provided sharded state dictionary + and optional cached metadata. + - The metadata is stored in memory to speed up future saves. + """ + with debug_time("_get_distribution", logger): + cls._validate_params(algo) + fully_parallel = algo == 'fully_parallel' + sharded_part, common_state_dict = save_preprocess( + sharded_state_dict, cached_metadata is None + ) + cacheable_distribution = cls._get_distribution( + fully_parallel, sharded_part, parallelization_group, cached_metadata + ) + if cacheable_distribution is not None: + shard_to_saving_rank, _, _, _ = cacheable_distribution + cls._remove_redundant_data( + fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group + ) + + return ( + MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part), + cacheable_distribution, + ) + + @property + def is_hollow(self): + """ + True iff tensors had been extracted and have not been inserted back yet. + """ + return self._is_hollow + + @property + def _sharded_tensors(self): + # Three possible states for sharded_tensor: + # 1. sharded_tensor with data (.data = tensor) + # 2. sharded_tensor hollow (.data = None, .orig_device = orig_device) + # 3. 
removed sharded_tensor (.data = None, no device information) + # TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data + if self.is_hollow: + for sh_base in nested_values(self.sharded_state_dict): + # FIXME: Hacky way to store the original device of the popped tensor + if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'): + yield sh_base + else: + for sh_base in nested_values(self.sharded_state_dict): + if isinstance(sh_base, ShardedTensor) and sh_base.data is not None: + yield sh_base + + @property + def tensors(self) -> Iterator[torch.Tensor]: + """ + Get the tensor data from the state dict. + """ + assert not self.is_hollow # TODO raise exception + return map(lambda sh_ten: sh_ten.data, self._sharded_tensors) + + @property + def common_state_dict(self) -> Dict: + """ + Get the common state dict from the state dict. + """ + return self.common + + def pop_tensors(self) -> List[torch.Tensor]: + """ + Extracts the tensor data from the wrapped state dict, preserving metadata. + + Replaces the tensor data in sharded_tensors with device type of extracted tensors. + After this operation, the state dictionary is "hollow", containing no tensor data. + Further calls to `pop_tensor` will raise an error. + + @return List of extracted tensors + """ + assert not self.is_hollow # TODO raise exception + result = [] + for sh_ten in self._sharded_tensors: + result.append(sh_ten.data) + # FIXME: Hacky way to store the original device, which is not included in the metadata + setattr(sh_ten, 'orig_device', sh_ten.data.device.type) + sh_ten.data = None + self._is_hollow = True + return result + + def insert_tensors(self, tensor_data: Iterable[torch.Tensor]): + """ + Reverse of `pop_tensors`. Replaces device type in sharded_tensors with actual values + Value of `self` is considered to be the same after: + ``` + self.insert_tensors(self.pop_tensors()) + ``` + """ + assert self.is_hollow # TODO raise exception + for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data): + # FIXME: Hacky way to store the original device + if sh_ten.orig_device == ten.device.type: + delattr(sh_ten, 'orig_device') + # Tensor might be on non-original device + sh_ten.data = ten + self._is_hollow = False + + def init_tensors(self): + """ + Initializes empty tensors with the same properties as the original tensors. + + This function should only be called after the original tensors have been popped. + It ensures that the newly created empty tensors match the shape, + dtype, and device of the originals, but contain no data. + """ + assert self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + # Hacky way to retrieve the original device + sh_ten.init_data(sh_ten.orig_device) + delattr(sh_ten, 'orig_device') + self._is_hollow = False + + def copy_tensors_to_cpu(self, non_blocking=False): + """ + Stores CPU copies of tensors in the state_dict, replacing the originals, + but without destroying them. + The original devices are remembered for restoration with restore_tensor_device(). + Using non_blocking=True allows for asynchronous copying. 
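+
+        A typical round trip looks roughly like this (a sketch; `ta_state_dict` is an
+        illustrative name, and synchronization is the caller's responsibility when
+        non_blocking=True is used):
+
+            ta_state_dict.copy_tensors_to_cpu(non_blocking=True)
+            torch.cuda.synchronize()  # ensure the device-to-host copies have finished
+            ...  # persist or inspect the CPU copies
+            ta_state_dict.restore_tensor_device()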
+ """ + assert not self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + if sh_ten.data.device.type == 'cpu': + # Skip cloning if it's already confirmed to be a copy + if not hasattr(sh_ten, 'orig_device'): + sh_ten.data = sh_ten.data.clone() + else: + # FIXME: Hacky way to store the original device + if not hasattr(sh_ten, 'orig_device'): + setattr(sh_ten, 'orig_device', sh_ten.data.device.type) + sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking) + + def restore_tensor_device(self, non_blocking=True): + """ + Restores all tensors to their original devices, if a move is required. + Using non_blocking=True allows for asynchronous copying. + """ + assert not self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + # FIXME: Hacky way to store the original device + if hasattr(sh_ten, 'orig_device'): + sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking) + delattr(sh_ten, 'orig_device') + + def _insert_sharded_data( + self, fully_parallel, sharded_part, parallelization_group, exchange_algo + ): + loaded_tensors = {} + for sh_ten in self._sharded_tensors: + loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data + if fully_parallel: + with debug_time("_get_distribution", logger): + distribution = self._get_distribution( + fully_parallel, sharded_part, parallelization_group + ) + if distribution is not None: + unloaded_shards = {} + for sh_base in nested_values(sharded_part): + # TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data + if isinstance(sh_base, ShardedTensor): + shard_id = _sharded_tensor_shard_id(sh_base) + if shard_id not in loaded_tensors: + unloaded_shards[shard_id] = sh_base + + with debug_time("exchange_by_distribution", logger): + loaded_tensors = exchange_by_distribution( + loaded_tensors, + unloaded_shards, + distribution, + parallelization_group, + exchange_algo, + ) + torch.cuda.synchronize() + loaded_objects = {} + for sh_base in nested_values(self.sharded_state_dict): + if not isinstance(sh_base, ShardedTensor): + assert isinstance(sh_base, ShardedObject) + loaded_objects[_sharded_object_id(sh_base)] = sh_base.data + + def load_sharded_base(x: Any): + if isinstance(x, ShardedTensor): + shard_id = _sharded_tensor_shard_id(x) + assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys()) + x = loaded_tensors[shard_id] + if isinstance(x, ShardedObject): + object_id = _sharded_object_id(x) + assert object_id in loaded_objects, (x, object_id, loaded_objects.keys()) + x = loaded_objects[object_id] + return x + + dict_list_map_inplace(load_sharded_base, sharded_part) + + @debug_time("to_state_dict", logger) + def to_state_dict( + self, + sharded_state_dict: ShardedStateDict, + algo: str = 'atomic', + exchange_algo: str = 'broadcast', + validate_access_integrity: bool = True, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + ): + """ + Convert tensor-aware dict back to the original state_dict + """ + with debug_time("load_preprocess_and_state_dict_manipulations", logger): + assert not self.is_hollow # TODO raise exception + self._validate_params(algo) + fully_parallel = algo == 'fully_parallel' + + # __adding__ common part + recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common) + + if not sharded_state_dict: + return recreated_state_dict + # TODO validate self.sharded_state_dict"] and sharded_state_dict are compatible + + sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( + 
sharded_state_dict + ) + # __adding__ nonpersistent part + merge(recreated_state_dict, nonpersistent_state_dict) + + sharded_part, _ = extract_sharded_base(sharded_state_dict) + + if validate_access_integrity: + with debug_time("validate_sharding_integrity", logger): + validate_sharding_integrity(determine_global_metadata(sharded_part)[1]) + + # load sharded tensors and sharded objects to sharded_part + with debug_time("_insert_sharded_data", logger): + self._insert_sharded_data( + fully_parallel, sharded_part, parallelization_group, exchange_algo + ) + with debug_time("apply_factory_merges", logger): + sharded_part = apply_factory_merges(sharded_part, sh_ten_factories) + # __adding__ sharded_part + merge(recreated_state_dict, sharded_part) + return recreated_state_dict diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py index 9186e4790aab1803d23f63023711e7d5ece6ac80..815c9502657a49e91b8e6bd88408e269d132c3cb 100644 --- a/megatron/core/dist_checkpointing/utils.py +++ b/megatron/core/dist_checkpointing/utils.py @@ -1,219 +1,319 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Helpers for manipulating sharded tensors and sharded state dicts. """ - -from typing import Dict, Optional, Tuple - -from .dict_utils import dict_list_map_inplace, extract_matching_values -from .mapping import ( - LocalNonpersistentObject, - ShardedBase, - ShardedObject, - ShardedStateDict, - ShardedTensor, - ShardedTensorFactory, - StateDict, -) - -# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor -# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple) -_ShardId = Tuple[str, tuple, Optional[tuple]] - - -def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: - """Unique id of the sharded tensor data. - - Should yield the same value for same data replicated on different ranks. - - Args: - sharded_tensor (ShardedTensor): sharded tensor representing the data shard - - Returns (tuple): unique id of a data shard - """ - f_range = sharded_tensor.flattened_range - return ( - sharded_tensor.key, - sharded_tensor.global_offset, - None if f_range is None else (f_range.start, f_range.stop), - ) - - -def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId: - """Unique id of the sharded object data. - - Should yield the same value for same data replicated on different ranks. - - Args: - sharded_object (ShardedObject): sharded object representing the data shard - - Returns (tuple): unique id of a data shard - """ - return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape) - - -def extract_sharded_tensors( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - """Extract a dict consisting of only ShardedTensor objects - from a given state dict with any objects. 
- - Args: - sharded_state_dict: state dict possibly containing ShardedTensor objects - - Returns: - Tuple[ShardedStateDict, StateDict]: tuple of: - - state dict with all ShardedTensor (keeping the original state dict structure) - - state dict with all objects other than ShardedTensor - (keeping the original state dict structure) - """ - return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor)) - - -def extract_sharded_tensors_and_factories( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects - from a given state dict with any objects. - - Args: - sharded_state_dict: - state dict possibly containing ShardedTensor and ShardedTensorFactory objects - - Returns: - Tuple[ShardedStateDict, StateDict]: tuple of: - - state dict with all ShardedTensor and ShardedTensorFactory - (keeping the original state dict structure) - - state dict with all other objects (keeping the original state dict structure) - """ - return extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory)) - ) - - -def extract_sharded_tensors_or_nonpersistent( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory - and LocalNonpersistentObject objects from a given state dict with any objects. - - Args: - sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory - and LocalNonpersistentObject objects - - Returns: - Tuple[ShardedStateDict, StateDict]: tuple of: - - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject - (keeping the original state dict structure) - - state dict with all other objects (keeping the original state dict structure) - """ - return extract_matching_values( - sharded_state_dict, - lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)), - ) - - -def extract_sharded_base( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - """Extract a dict consisting of only ShardedBase from a given state dict with any objects. - - Args: - sharded_state_dict: state dict possibly containing ShardedBase objects - - Returns: - Tuple[ShardedStateDict, StateDict]: tuple of: - - state dict with all ShardedBase objects (keeping the original state dict structure) - - state dict with all other objects (keeping the original state dict structure) - """ - return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase)) - - -def extract_nonpersistent( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict. - - Args: - sharded_state_dict: state dict possibly containing LocalNonpersistentObjects - - Returns: - Tuple[ShardedStateDict, StateDict]: tuple of: - - state dict with all LocalNonpersistentObjects - (keeping the original state dict structure) - - state dict with all other objects (keeping the original state dict structure) - """ - - return extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject) - ) - - -def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str): - """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*. 
- - Args: - sharded_state_dict (ShardedStateDict): sharded state dict - prefix (str): prefix to be prepended - - Returns: - None: state dict is modified in-place - """ - - def add_prefix(t): - if isinstance(t, ShardedBase): - t.key = f'{prefix}{t.key}' - return t - - dict_list_map_inplace(add_prefix, sharded_state_dict) - - -def replace_prefix_for_sharding( - sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str -): - """Replaces the given prefix in *all* sharded keys in a given state dict. - - Errors out if some key does not begin with a given prefix. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in - old_prefix (str): prefix to be replaced in each key - new_prefix (str): new prefix - - Returns: - None: state dict is modified in place - """ - - def _replace_prefix(x): - if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): - if not x.key.startswith(old_prefix): - raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}') - x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 - return x - - dict_list_map_inplace(_replace_prefix, sharded_state_dict) - - -def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]): - """Replaces prefixes *only in keys matching* with one of prefixes in the map. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in - prefix_map (Dict[str, str]): - map of old->new prefixes. The first matching prefix for each key is used - - Returns: - None: state dict is modified in place - """ - - def _replace_prefixes(x): - if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): - return x - for old_prefix, new_prefix in prefix_map.items(): - if x.key.startswith(old_prefix): - x.key = ( - f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 - ) - break - return x - - dict_list_map_inplace(_replace_prefixes, sharded_state_dict) +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Helpers for manipulating sharded tensors and sharded state dicts. """ +import logging +from contextlib import contextmanager +from time import time +from typing import Dict, Optional, Tuple + +from .dict_utils import dict_list_map_inplace, extract_matching_values +from .mapping import ( + LocalNonpersistentObject, + ShardedBase, + ShardedObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, +) + +# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor +# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple) +_ShardId = Tuple[str, tuple, Optional[tuple]] + + +def zip_strict(*args): + """ + Alternative to Python's builtin zip(..., strict=True) (available in 3.10+). + Apart from providing functionality in earlier versions of Python is also more verbose. + (Python's zip does not print lengths, only which iterable has finished earlier) + """ + args = [list(a) for a in args] + lens = [len(a) for a in args] + assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!" + return zip(*args) + + +def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: + """Unique id of the sharded tensor data. + + Should yield the same value for same data replicated on different ranks. 
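+
+    For example, a non-flattened shard yields an id of the form
+    ('decoder.layers.0.mlp.weight', (0, 0), None) (the key here is purely illustrative),
+    while a flattened shard carries its (start, stop) range as the last element.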
+ + Args: + sharded_tensor (ShardedTensor): sharded tensor representing the data shard + + Returns (tuple): unique id of a data shard + """ + f_range = sharded_tensor.flattened_range + return ( + sharded_tensor.key, + sharded_tensor.global_offset, + None if f_range is None else (f_range.start, f_range.stop), + ) + + +def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId: + """Unique id of the sharded object data. + + Should yield the same value for same data replicated on different ranks. + + Args: + sharded_object (ShardedObject): sharded object representing the data shard + + Returns (tuple): unique id of a data shard + """ + return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape) + + +def extract_sharded_tensors( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor objects + from a given state dict with any objects. + + Args: + sharded_state_dict: state dict possibly containing ShardedTensor objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor (keeping the original state dict structure) + - state dict with all objects other than ShardedTensor + (keeping the original state dict structure) + """ + return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor)) + + +def extract_sharded_tensors_and_factories( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects + from a given state dict with any objects. + + Args: + sharded_state_dict: + state dict possibly containing ShardedTensor and ShardedTensorFactory objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor and ShardedTensorFactory + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + return extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory)) + ) + + +def extract_sharded_tensors_or_nonpersistent( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory + and LocalNonpersistentObject objects from a given state dict with any objects. + + Args: + sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory + and LocalNonpersistentObject objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + return extract_matching_values( + sharded_state_dict, + lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)), + ) + + +def extract_sharded_base( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedBase from a given state dict with any objects. 
+ + Args: + sharded_state_dict: state dict possibly containing ShardedBase objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedBase objects (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase)) + + +def extract_nonpersistent( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict. + + Args: + sharded_state_dict: state dict possibly containing LocalNonpersistentObjects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all LocalNonpersistentObjects + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + + return extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject) + ) + + +def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str): + """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict + prefix (str): prefix to be prepended + + Returns: + None: state dict is modified in-place + """ + + def add_prefix(t): + if isinstance(t, ShardedBase): + t.key = f'{prefix}{t.key}' + return t + + dict_list_map_inplace(add_prefix, sharded_state_dict) + + +def replace_prefix_for_sharding( + sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str +): + """Replaces the given prefix in *all* sharded keys in a given state dict. + + Errors out if some key does not begin with a given prefix. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in + old_prefix (str): prefix to be replaced in each key + new_prefix (str): new prefix + + Returns: + None: state dict is modified in place + """ + + def _replace_prefix(x): + if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): + if not x.key.startswith(old_prefix): + raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}') + x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 + return x + + dict_list_map_inplace(_replace_prefix, sharded_state_dict) + + +def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]): + """Replaces prefixes *only in keys matching* with one of prefixes in the map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in + prefix_map (Dict[str, str]): + map of old->new prefixes. The first matching prefix for each key is used + + Returns: + None: state dict is modified in place + """ + + def _replace_prefixes(x): + if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): + return x + for old_prefix, new_prefix in prefix_map.items(): + if x.key.startswith(old_prefix): + x.key = ( + f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 + ) + break + return x + + dict_list_map_inplace(_replace_prefixes, sharded_state_dict) + + +fallback_logger = logging.getLogger(__name__) +__LOGGER_NAME_STACK = [] +__LOGGER_STACK = [] + + +@contextmanager +def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None): + """Context manager for managing logger and name stack. 
+ + Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical + logging and contextual logger usage. Ensures the logger stack is restored afterward. + + Args: + name (str, optional): Name to add to the logger stack. Defaults to None. + current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in + the stack or a fallback if none exist. + + Yields: + Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and + the current logger for the block. + + Example: + with logger_stack("scope", logger): + logger.info("Log within 'scope'") + """ + if name: + __LOGGER_NAME_STACK.append(name) + if current_logger: + __LOGGER_STACK.append(current_logger) + last_logger = current_logger + elif __LOGGER_STACK: + last_logger = __LOGGER_STACK[-1] + else: + last_logger = fallback_logger + try: + yield ".".join(__LOGGER_NAME_STACK), last_logger + finally: + if name and __LOGGER_NAME_STACK: + __LOGGER_NAME_STACK.pop(-1) + if current_logger and __LOGGER_STACK: + __LOGGER_STACK.pop(-1) + + +@contextmanager +def debug_time( + name: str, logger: Optional[logging.Logger] = None, threshold: float = float("-inf"), level=None +): + """Simple context manager for timing functions/code blocks. + + Args: + name (str): Label describing the code being measured. + logger (logging.Logger, optional): Logger for output. Defaults to the lowest logger. + threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster. + level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset; + WARNING otherwise. + """ + with logger_stack(name, logger) as (stacked_name, last_logger): + start = time() + try: + yield + finally: + result = time() - start + if result < threshold: + return + if level is None: + level = logging.DEBUG if threshold == float("-inf") else logging.WARNING + last_logger.log(level, f"{stacked_name} took {result:.4f}s") + + +def debug_msg(msg: str): + """Logs a debug message using the current logger stack. + + This function formats and logs a debug message with the current logger + and name stack, preserving context from the logger_stack context manager. + + Args: + msg (str): The message to be logged at the debug level. + + Example: + debug_msg("Checkpoint initialized") + # Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name") + """ + with logger_stack(None, None) as (stacked_name, last_logger): + last_logger.debug(f"{stacked_name} {msg}") diff --git a/megatron/core/distributed/custom_fsdp/__init__.py b/megatron/core/distributed/custom_fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f907aca691c2098736ca541bf52a4213eaaa6601 --- /dev/null +++ b/megatron/core/distributed/custom_fsdp/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from .fully_sharded_data_parallel import FullyShardedDataParallel diff --git a/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py b/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..381e8a4184044a503f43847f8a790d5ec025e3c5 --- /dev/null +++ b/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,687 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
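+"""Custom Fully Sharded Data Parallel (FSDP) implementation for Megatron Core models."""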
+
+import functools
+import logging
+from contextlib import contextmanager
+from enum import Enum, auto
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from megatron.core import parallel_state
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
+    AllGatherPipeline,
+    BucketingPolicy,
+    GradReducePipeline,
+    ParamAndGradBuffer,
+    PrefetchOrder,
+)
+from megatron.core.distributed.data_parallel_base import _BaseDataParallel
+from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
+from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.transformer_layer import TransformerLayer
+from megatron.core.utils import is_float8tensor, is_submodule, log_single_rank
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingState(Enum):
+    """States of an FSDP parameter group, coupled with the sharding activity
+    of parameters and gradients during training."""
+
+    # From pre-forward until post-forward, where parameters should be unsharded
+    FORWARD = auto()
+    # Prior to backward computation, where parameters should be unsharded
+    PRE_BACKWARD = auto()
+    # After backward computation, where gradients should be re-sharded
+    POST_BACKWARD = auto()
+    # Before and after module forward computation, or before pre-backward and
+    # after post-backward states, where no un/sharding activity happens
+    IDLE = auto()
+
+
+class FullyShardedDataParallel(_BaseDataParallel):
+    """Fully Sharded Data Parallel training for MCore models.
+
+    A distributed training wrapper that shards model parameters, gradients and optimizer
+    states across data parallel workers. Integrates seamlessly with MCore's tensor
+    and expert parallelism features.
+
+    We support the following modes:
+    - no_shard: Traditional data parallel training without parameter sharding.
+    - optim: Shards optimizer states (and the main weights used for mixed-precision
+        training); this is conceptually close to "ZeRO-1". The `optim_grads` and
+        `optim_grads_params` modes below also shard the main weights during
+        mixed-precision training; this is omitted from their descriptions.
+    - optim_grads: Shards gradients and optimizer states; this is conceptually close
+        to "ZeRO-2".
+    - optim_grads_params: Shards parameters, gradients and optimizer states; this is
+        conceptually close to "ZeRO-3".
+
+    Key Features:
+    - Compatible with MCore's tensor, context and expert parallelism
+    - Automatic mixed precision training (BF16/FP8)
+    - Gradient accumulation and bucketing
+    - Optimized activation recompute with shard-aware communication: when recomputing
+        a whole Transformer layer, gather parameters once for both the recomputation
+        and the backward computation
+    - Compatible with MCore's distributed checkpointing
+
+    Args:
+        config: Transformer config object.
+        ddp_config: FullyShardedDataParallel config object.
+        module: Underlying model.
+        fsdp_unit_modules: List of modules that should be treated as an FSDP unit,
+            i.e., the minimum releasable model unit. If not provided, defaults to
+            [TransformerLayer, LanguageModelEmbedding] for GPT-like models.
+        disable_bucketing: If true, force assign all parameters to a single bucket.
If false, + use standard bucketing policy: assign parameters to smaller buckets and all-reduce + per bucket. + Examples: + >>> model = GPTModel(config) + >>> model = FullyShardedDataParallel( + ... config, + ... model, + ... ddp_config, + ... fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding], + ... ) + """ + + # TODO: add hybrid FSDP (shard model states in a partial DP domain) + def __init__( + self, + config: TransformerConfig, + ddp_config: DistributedDataParallelConfig, + module: torch.nn.Module, + fsdp_unit_modules: Optional[List[torch.nn.Module]] = None, + disable_bucketing: bool = False, + device: Optional[torch.device] = None, + ): + super().__init__(config=config, module=module) + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + + self.module = module + self.ddp_config = ddp_config + log_single_rank( + logger, + logging.INFO, + f'Setting up DistributedDataParallel with config {self.ddp_config}', + ) + + self.bucket_size = self.ddp_config.bucket_size + if disable_bucketing: + self.bucket_size = None + self.device = device if device else torch.cuda.current_device() + + self.param_to_bucket_group = {} + + if fsdp_unit_modules is not None: + self.fsdp_unit_modules = fsdp_unit_modules + else: + self.fsdp_unit_modules = [TransformerLayer] + if not getattr(self.module, "share_embeddings_and_output_weights", False): + self.fsdp_unit_modules.append(LanguageModelEmbedding) + self.main_weights = True + self.data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True + ) + self.expert_data_parallel_group = parallel_state.get_expert_data_parallel_group() + + # Determine if we should delay the gradient reduction. + self.is_delay_grad_reduce = self.ddp_config.data_parallel_sharding_strategy in [ + "no_shard", + "optim", + ] + + if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": + assert self.ddp_config.overlap_param_gather + if not self.is_delay_grad_reduce: + assert self.ddp_config.overlap_grad_reduce + self._init_fsdp_param_and_grad_buffer() + self._register_fsdp_hooks(self.module) + + # Delete references to weight_tensor if they exist since we don't want two parameter copies + # if we re-mapped parameters (which happens when we use the distributed optimizer). + # This is a temporary workaround around a TE bug that is fixed with + # https://github.com/NVIDIA/TransformerEngine/pull/719. + @torch.no_grad() + def unmap_weight_tensor(m): + if hasattr(m, 'weight_tensor'): + m.weight_tensor = None + + self.module.apply(unmap_weight_tensor) + + def _init_fsdp_param_and_grad_buffer(self): + if self.config.calculate_per_token_loss: + # We don't need to scale the gradients in this case. + gradient_scaling_factor = None + expert_gradient_scaling_factor = None + else: + if self.ddp_config.average_in_collective: + # FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR. + raise Exception("Not supported") + else: + data_parallel_world_size = parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) + gradient_scaling_factor = 1.0 / data_parallel_world_size + expert_gradient_scaling_factor = 1.0 / data_parallel_world_size + + # Initialize the param and grad buffer. 
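+        # The buffer groups parameters into buckets (per dtype / expert flag / FSDP unit)
+        # and backs them with contiguous storage for model weights, gradients and, for
+        # mixed precision, main weights, depending on the chosen sharding strategy.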
+ self.data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy + self.param_to_name = {p: name for name, p in self.module.named_parameters()} + self.param_and_grad_buffer = ParamAndGradBuffer( + self.ddp_config, + self.module, + bucketing_policy=BucketingPolicy( + suggested_bucket_size=self.bucket_size, + fsdp_unit_modules=( + # Only when model weights need to be sharded, we need to + # identify the minimum releasable model unit, which is the + # FSDP Unit Module. + self.fsdp_unit_modules + if self.data_parallel_sharding_strategy == "optim_grads_params" + else [] + ), + data_parallel_sharding_strategy=self.data_parallel_sharding_strategy, + ), + data_parallel_group=self.data_parallel_group, + expert_data_parallel_group=self.expert_data_parallel_group, + preserve_fp32_weights=self.ddp_config.preserve_fp32_weights, + grad_reduce_in_fp32=self.ddp_config.grad_reduce_in_fp32, + gradient_scaling_factor=gradient_scaling_factor, + expert_gradient_scaling_factor=expert_gradient_scaling_factor, + device=self.device, + reset_parameters_for_meta_device_init_module=self.config.init_model_with_meta_device, + ) + self.param_and_grad_buffer + + self.side_stream_for_buffer_copy_and_grad_accum = torch.cuda.Stream() + + # Initialize the reduce-scatter pipeline. + self.grad_reduce_pipeline = GradReducePipeline( + self.param_and_grad_buffer, cuda_stream=self.side_stream_for_buffer_copy_and_grad_accum + ) + + # Initialize the all-gather pipeline. + self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer) + + self.suggested_RS_queue_capacity = self.ddp_config.suggested_communication_unit_size + self.suggested_AG_prefetch_size = self.ddp_config.suggested_communication_unit_size + + def _register_fsdp_hooks(self, root_module): + """Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model. + + This function sets up various hooks required for FSDP operations, including parameter + resharding/unsharding and gradient handling. The registered hooks are: + - Pre-forward hook: Unshards parameters before forward pass + - Post-forward hook: Reshards parameters after forward pass + - Pre-backward hook: Unshards parameters before backward pass + - Post-backward hook: Reshards parameters after backward pass + - Gradient accumulation hook: Handles gradient accumulation and reduction across devices + + Args: + root_module: The PyTorch module to register FSDP hooks on + + Note: + These hooks are essential for FSDP's memory efficiency as they manage: + 1. Dynamic parameter sharding/unsharding to reduce memory footprint + 2. Proper gradient synchronization across distributed processes + 3. Gradient accumulation for large batch training + + Returns: + None + """ + + # Initialize module training state. + for m in root_module.modules(): + setattr(m, "_training_state", TrainingState.IDLE) + + self.forward_pre_hooks = {} + self.forward_hooks = {} + self.backward_pre_hooks = {} + + """ + An FSDP unit is a module designed to manage the lifecycle of model parameters + in Fully Sharded Data Parallel (FSDP) training. It ensures that parameters + are only used within the module and are released immediately after + the forward and backward computations are completed. + This approach is crucial for efficient memory management, as releasing + parameters too early can lead to issues if other computations depend on them. + + `optim` and `optim_grads` do not require FSDP units because they do not + shard model parameters. 
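+        `optim_grads_params` is therefore the only strategy for which FSDP unit
+        modules are tracked below.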
+ """ + if self.data_parallel_sharding_strategy != "optim_grads_params": + fsdp_unit_modules = [] + else: + fsdp_unit_modules = self.fsdp_unit_modules + + def release_module_parameters(module, *unused): + for param in module.parameters(): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + self.all_gather_pipeline.release_bucket(bucket_id) + + if not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp: + release_params_fp8_transpose_cache(module.parameters()) + + def release_params_fp8_transpose_cache(params): + for param in params: + if is_float8tensor(param): + param._transpose_invalid = True + param._transpose = None + + def all_gather_module_parameters( + module, + *unused, + prefetch=True, + prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER, + wait_bucket_ready=True, + ): + wait_list = [] + ag_pipeline = self.all_gather_pipeline + for param in module.parameters(): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + ag_pipeline.queue_bucket_to_all_gather( + bucket_id, + prefetch=prefetch, + prefetch_order=prefetch_order, + suggested_AG_prefetch_size=self.suggested_AG_prefetch_size, + ) + wait_list.append(bucket_id) + + if wait_bucket_ready: + for bucket_id in wait_list: + ag_pipeline.wait_bucket_ready(bucket_id) + + def _post_backward(module, *unused): + release_module_parameters(module) + module._training_state = TrainingState.IDLE + + def _pre_forward(module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]): + input_training_state = module._training_state + fsdp_forward_prefetch = True + if input_training_state == TrainingState.PRE_BACKWARD: + # In activation recomputation case, we need to cancel forward prefetch. + fsdp_forward_prefetch = False + else: + module._training_state = TrainingState.FORWARD + + if isinstance(module, tuple(fsdp_unit_modules)): + wait_list = [] + for param in module.parameters(): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + self.all_gather_pipeline.queue_bucket_to_all_gather( + bucket_id, + prefetch=fsdp_forward_prefetch, + suggested_AG_prefetch_size=self.suggested_AG_prefetch_size, + ) + wait_list.append(bucket_id) + for bucket_id in wait_list: + self.all_gather_pipeline.wait_bucket_ready(bucket_id) + + if not torch.is_grad_enabled(): + return args, kwargs + + # Register the backward function to release the parameters. + args_list, args_spec = tree_flatten(args) + kwargs_list, kwargs_spec = tree_flatten(kwargs) + args_kwargs_list = list(args_list) + list(kwargs_list) + inp_tensor_indices: List[int] = [] + inp_tensors: List[torch.Tensor] = [] + for i, obj in enumerate(args_kwargs_list): + if torch.is_tensor(obj) and obj.requires_grad: + inp_tensor_indices.append(i) + inp_tensors.append(obj) + if len(inp_tensors) == 0: + return args, kwargs + inp_tensors = RegisterFSDPBackwardFunction.apply( + functools.partial(_post_backward, module), *inp_tensors + ) + for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors): + args_kwargs_list[inp_tensor_idx] = inp_tensor + args_list = args_kwargs_list[: len(args_list)] + kwargs_list = args_kwargs_list[len(args_list) :] + args = tree_unflatten(args_list, args_spec) + kwargs = tree_unflatten(kwargs_list, kwargs_spec) + + return args, kwargs + else: + # All-gather the parameters in every forward pass for FSDP. 
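+                # Only the module's own (non-recursive) parameters are queued here,
+                # and we wait for their buckets to be ready before the forward
+                # computation uses them.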
+ for param in module.parameters(recurse=False): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + self.all_gather_pipeline.queue_bucket_to_all_gather( + bucket_id, + prefetch=fsdp_forward_prefetch, + suggested_AG_prefetch_size=self.suggested_AG_prefetch_size, + ) + for param in module.parameters(recurse=False): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + self.all_gather_pipeline.wait_bucket_ready(bucket_id) + + return args, kwargs + + if self.ddp_config.overlap_param_gather: + fsdp_modules = [] + for name, module in root_module.named_modules(): + if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": + if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules): + continue + + if isinstance(module, tuple(fsdp_unit_modules)): + fsdp_modules.append(module) + + self.forward_pre_hooks[f'module {name} parameter all-gather'] = ( + module.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True) + ) + + def _pre_backward(module: nn.Module, *unused): + module._training_state = TrainingState.PRE_BACKWARD + if isinstance(module, tuple(fsdp_unit_modules)): + all_gather_module_parameters( + module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER + ) + + def _root_pre_backward(module: nn.Module, *unused): + """Marks the module's training state as 'pre_backward' before the + backprop, this function is registered on the root module. + + This marking enables us to determine whether forward pass needs to + perform reshard/unshard operations in activation recomputation + scenarios. + """ + for module in root_module.modules(): + if isinstance(module, tuple(fsdp_unit_modules)): + module._training_state = TrainingState.PRE_BACKWARD + for param in module.parameters(): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True) + self.all_gather_pipeline.release_bucket(bucket_id) + + def _post_forward(module: nn.Module, input: Any, output: Any): + # When composing with module-hook-based activation checkpointing, the + # post-backward hook is responsible for the reshard + if module._training_state == TrainingState.PRE_BACKWARD: + return output + + release_module_parameters(module) + module._training_state = TrainingState.IDLE + + return output + + def _release_module_fp8_transpose_cache(module: nn.Module, *unused): + release_params_fp8_transpose_cache(module.parameters(recurse=False)) + + if self.data_parallel_sharding_strategy == "optim_grads_params": + fsdp_modules = [] + for name, module in root_module.named_modules(): + if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules): + continue + + if isinstance(module, tuple(fsdp_unit_modules)): + fsdp_modules.append(module) + self.forward_hooks[f"release module {name} parameters"] = ( + module.register_forward_hook(_post_forward, prepend=False) + ) + self.backward_pre_hooks[f"all-gather module {name} parameters"] = ( + module.register_full_backward_pre_hook(_pre_backward) + ) + elif not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp: + self.forward_hooks[f"remove module {name} fp8 transpose cache"] = ( + module.register_forward_hook( + _release_module_fp8_transpose_cache, prepend=False + ) + ) + self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook( + _root_pre_backward + ) + + def _make_param_hook(param: torch.nn.Parameter): + """ + Creates the all-reduce / reduce-scatter hook for backprop. 
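+
+            The hook accumulates (or copies) `param.grad` into `param.main_grad` and,
+            when `overlap_grad_reduce` is enabled, marks the parameter ready in the
+            grad-reduce pipeline so its bucket can be reduced asynchronously.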
+ """ + + wait_previous_grad_reduce = not self.is_delay_grad_reduce + + # FIXME: Use insert forward op to replace grad acc hook, which will + # be lost after parameter data movement. For example, module.cuda() + # will cause the registered grad acc hook to be lost. + def param_hook(*unused): + if param.requires_grad: + if self.ddp_config.overlap_grad_reduce: + assert ( + param.grad is not None + ), 'param.grad being None is not safe when overlap_grad_reduce is True' + + if param.grad is not None and ( + not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False) + ): + if self.is_delay_grad_reduce: + param.main_grad.add_(param.grad.data) + else: + param.main_grad.copy_(param.grad.data) + param.grad = None + + if self.ddp_config.overlap_grad_reduce and ( + not self.is_delay_grad_reduce or self.is_last_microbatch + ): + gr_pipeline = self.grad_reduce_pipeline + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + gr_pipeline.place_bucket(bucket_id) + go_rs = gr_pipeline.mark_item_ready(param, async_rs=True) + if go_rs and wait_previous_grad_reduce: + gr_pipeline.wait_for_previous_grad_reduce( + recommeded_queue_capacity=self.suggested_RS_queue_capacity + ) + + return param_hook + + # Register backward gradient accumulation hook for each parameter. + self.grad_accs = [] + for param in root_module.parameters(): + bucket_id = self.param_and_grad_buffer.param_to_param_group[param] + wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer + if param.requires_grad: + if wbuf and wbuf.is_data_distributed: + wbuf.fetch_bucket(and_allocate_params_data=True) + + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator function. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(_make_param_hook(param)) + self.grad_accs.append(grad_acc) + + if wbuf and wbuf.is_data_distributed: + wbuf.free_bucket_storage() + + @contextmanager + def no_sync(self): + """ + Context manager that turns off gradient synchronization. + For grads shard mode there will actually always be gradient sync happening. + """ + # FIXME: Better handling of grads shard mode and no_sync in the training loop so that + # the code doesn't bog down developers. + self.is_last_microbatch = False + try: + yield + finally: + self.is_last_microbatch = True + + def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False): + """ + Initiates param sync (all-gather) communication operations for all model parameters. + + By default, when overlap_param_gather is set to True, dispatches asynchronous communication + calls; when overlap_param_gather is set to False, calls synchronous communication + ops. Can override this default behavior using flags below. + + Args: + force_sync (bool, optional): force synchronous collective regardless of + other settings. + force_dispatch (bool, optional): force dispatch regardless of other settings. + """ + if not force_sync and self.ddp_config.overlap_param_gather: + # All-gather the first bucket before the forward pass. 
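+            # The remaining buckets are gathered lazily by the forward pre-hooks
+            # (with prefetching) as the model executes.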
+ self.all_gather_pipeline.queue_bucket_to_all_gather(bucket_id=0, prefetch=False) + else: + self.all_gather_pipeline.reset() + for bucket_id in range(self.all_gather_pipeline.num_buckets): + self.all_gather_pipeline.all_gather_bucket_and_set_items( + bucket_id=bucket_id, async_op=True + ) + group = self.param_and_grad_buffer.parameter_groups[bucket_id] + if group.model_weight_buffer is None: + continue + + if group.model_weight_buffer.is_data_distributed: + # If model weight is sharded, we wait for the all-gather to complete and + # then release the bucket immediately to save memory usage. + self.all_gather_pipeline.wait_bucket_ready(bucket_id) + for bucket_id in range(self.all_gather_pipeline.num_buckets): + self.all_gather_pipeline.wait_bucket_ready(bucket_id) + + def start_grad_sync(self, *unused): + """ + Initiates grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, dispatches asynchronous communication + calls. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + if not self.ddp_config.overlap_grad_reduce: + if self.data_parallel_sharding_strategy == "no_shard": + self.param_and_grad_buffer.all_reduce_gradients( + async_op=self.ddp_config.overlap_grad_reduce + ) + else: + self.param_and_grad_buffer.reduce_scatter_gradients() + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, waits for asynchronous communication + calls to complete. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + if self.ddp_config.overlap_grad_reduce: + self.grad_reduce_pipeline.wait_for_previous_grad_reduce(0) + self.grad_reduce_pipeline.reset() + else: + self.start_grad_sync() + + self.param_and_grad_buffer.update_main_grads() + + if self.ddp_config.overlap_param_gather: + self.all_gather_pipeline.reset() + + def optimizer_named_parameters(self) -> List[Tuple[str, torch.Tensor]]: + """ + Returns a list of tuples containing the main weights and their corresponding names + for mixed-precision training, to be used by the optimizer for updates. + + Returns: + List[Tuple[str, torch.Tensor]]: A list of tuples, where each tuple + contains a main weight tensor and its corresponding name. + """ + return self.param_and_grad_buffer.optimizer_named_parameters + + def scale_gradients(self, scaling_factor: float): + """Scale all gradients inside the buffers by `scaling_factor`.""" + self.param_and_grad_buffer.scale_gradients(scaling_factor) + + def zero_grad_buffer(self): + """ + Zeros out all grad buffers. Needs to be called at the beginning of each + training iteration. + """ + for param in self.module.parameters(): + if param.requires_grad: + param.grad_added_to_main_grad = False + self.param_and_grad_buffer.zero_grad() + + def broadcast_params(self): + """ + Syncs parameters across all DP ranks. 
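+
+        Expert-parallel parameters are broadcast within the expert data-parallel
+        group; all other parameters use the regular data-parallel group (both with
+        context parallelism folded in).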
+ """ + for param in self.module.parameters(): + is_expert_parallel = not getattr(param, 'allreduce', True) + + if is_expert_parallel: + data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group( + with_context_parallel=True + ) + else: + data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True + ) + torch.distributed.broadcast( + param.data, + src=torch.distributed.get_global_rank(data_parallel_group, 0), + group=data_parallel_group, + ) + + def load_state_dict(self, state_dict, strict=True): + """ + Copies parameters and buffers from state_dict into the wrapped module and its + descendants. If strict is True, then the keys of state_dict must exactly match + the keys returned by this module’s state_dict() function. + """ + if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": + # make a copy of the state_dict to avoid modifying the input state_dict + state_dict = state_dict.copy() + state_dict_extra_states = {} + for key in list(state_dict.keys()): + if key.endswith("_extra_state"): + state_dict_extra_states[key] = state_dict[key] + del state_dict[key] + self.module.load_state_dict(state_dict_extra_states, strict=False) + + prefix = "module." + buffer = self.param_and_grad_buffer + for param_groups in buffer.parameter_groups: + wbuf = param_groups.model_weight_buffer + for model_param in wbuf.params: + if is_float8tensor(model_param): + fp8_meta = model_param._fp8_meta['scaling_fwd'] + fp8_meta_index = model_param._fp8_meta_index + model_param._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index]) + + param_name = f"{buffer.param_to_name[model_param]}"[len(prefix) :] + if param_name in state_dict: + if wbuf and wbuf.is_data_distributed: + model_param.fully_shard_param_local_shard.data.copy_( + state_dict[param_name] + ) + else: + model_param.data.copy_(state_dict[param_name]) + del state_dict[param_name] + self.module.load_state_dict(state_dict, strict=False) + return + self.module.load_state_dict(state_dict, strict=strict) + + +class RegisterFSDPBackwardFunction(torch.autograd.Function): + """ + Register a backward function that will be called after the backward pass + of the model. This function is used to release the parameters after the + backward pass. + """ + + @staticmethod + def forward(ctx, post_backward, *inputs: torch.Tensor): + ctx.post_backward = post_backward + return inputs + + @staticmethod + def backward(ctx, *grads: torch.Tensor): + ctx.post_backward() + return (None,) + grads diff --git a/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..a29768a6ab050cd085b2511408ac3adb00f6de6f --- /dev/null +++ b/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py @@ -0,0 +1,1971 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
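+"""Parameter and gradient buffer management for the Megatron Core custom FSDP."""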
+ +import dataclasses +import gc +import inspect +import logging +import math +import traceback +import warnings +from collections import namedtuple +from contextlib import ExitStack +from enum import Enum +from typing import Any, List, Optional, Tuple + +import torch + +from megatron.core import parallel_state +from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.utils import ( + is_float8tensor, + is_submodule, + is_te_min_version, + log_on_each_pipeline_stage, +) + +try: + from transformer_engine.pytorch import fp8_model_init + + # This will be used when "--use-fp8-params" is enabled. + # When BF16/FP16 parameters don't exist, we need to cast the FP32 main parameters to + # FP8 directly in the optimizer. + from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 +except: + pass + +try: + from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +except: + pass + + +logger = logging.getLogger(__name__) + + +def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: + """Alternate to ``assert`` when in the backward context to print the error + message ``s`` since otherwise, it is swallowed. + """ + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError(s) + + +def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: + """ + Allocate storage for ``tensor`` with the given size. + + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_allocated = tensor._typed_storage()._size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor._typed_storage()._size() + _p_assert( + tensor_storage_size == 0, + "Tensor storage should have been resized to be 0 but got PLACEHOLDEr", + ) + tensor._typed_storage()._resize_(size.numel()) + + +def _free_storage(tensor: torch.Tensor): + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_freed = tensor._typed_storage()._size() == 0 + if not already_freed: + _p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" + f"storage offset: {tensor.storage_offset()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" + f"tensor shape: {tensor.shape}", + ) + tensor._typed_storage()._resize_(0) + + +TensorItemIndex = namedtuple( + 'TensorItemIndex', ['global_data_index', 'size', 'item_id', 'bucket_id', 'shape'] +) +BucketIndex = namedtuple('BucketIndex', ['bucket_id', 'global_data_index', 'size', 'items']) +ShardBucketIndex = namedtuple( + 'ShardBucketIndex', + ['bucket_id', 'global_data_index', 'local_data_index', 'bucket_data_index', 'size'], +) + + +@dataclasses.dataclass +class BucketingPolicy: + """ + A policy for bucketing in Fully Sharded Data Parallel (FSDP) training. + + Attributes: + suggested_bucket_size (int): The suggested size of each bucket in num of elements. + fsdp_unit_modules (list): A list of module classes that are treated as a + single unit for FSDP bucketing. 
+ data_parallel_sharding_strategy (str): The strategy used for sharding + data parallel modules. + + Note: + This policy is used to configure the bucketing behavior in FSDP training. + """ + + suggested_bucket_size: Optional[int] = 40_000_000 + fsdp_unit_modules: List[torch.nn.Module] = dataclasses.field(default_factory=list) + data_parallel_sharding_strategy: str = 'no_shard' + + +def _pad(number_to_be_padded: int, divisor: int) -> int: + return int(math.ceil(number_to_be_padded / divisor) * divisor) + + +def build_data_parallel_buffer_index( + elements: List[torch.Size], + data_parallel_rank: int, + data_parallel_world_size: int, + is_data_distributed: bool, + ddp_config: DistributedDataParallelConfig, + bucket_id: int = 0, +) -> Tuple[int, List[tuple], List[tuple], List[tuple]]: + """ + Assuming that all input tensor elements are consecutively compose a global + buffer, give the index range of every tensor, every bucket and every in + bucket local buffer. + + Args: + elements (List[torch.Size]): List of input tensor. + data_parallel_rank (int): Rank of the current process in the data parallel group. + data_parallel_world_size (int): World size of the data parallel group. + bucket_id (int, optional): The id of the bucket. Defaults to 0. + + Returns: + Tuple[int, List[tuple], List[tuple], List[tuple]]: The index range of every tensor, + every bucket and every in bucket local buffer. + """ + + def _pad_if_needed(data_index: int) -> int: + """ + Pads data indices if using distributed optimizer (to ensure uniform sharding). + """ + if ddp_config.data_parallel_sharding_strategy != 'no_shard': + # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. + # This also helps cuBLAS pick more efficient algorithms for GEMMs. + # We now ensure that all buckets start at a memory address that is 256-byte + # aligned (128 values since params and grads use >= 16-bit precision). + return _pad(data_index, math.lcm(data_parallel_world_size, 128)) + return data_index + + def add_item(item_id, item, bucket, item_index_map, bucket_id): + bucket.append(item) + bucket_size = sum([it.numel() for it in bucket]) + item_index_map.append( + TensorItemIndex( + data_index + bucket_size - item.numel(), + item.numel(), + item_id=item_id, + bucket_id=bucket_id, + shape=item, + ) + ) + + item_index_map = [] + bucket = [] + data_index = 0 + for item_id, item in enumerate(elements): + add_item(item_id, item, bucket, item_index_map, bucket_id) + + bucket_size = sum([it.numel() for it in bucket]) + bucket_size = _pad_if_needed(bucket_size) + bucket_index = BucketIndex( + bucket_id, + data_index, + bucket_size, + items=list(filter(lambda x: x.bucket_id == bucket_id, item_index_map)), + ) + + shard_size = bucket_index.size // data_parallel_world_size + bucket_data_index = shard_size * data_parallel_rank + global_data_index = bucket_index.global_data_index + bucket_data_index + + if is_data_distributed: + shard_bucket_index = ShardBucketIndex( + bucket_id, global_data_index, 0, bucket_data_index, shard_size + ) + else: + shard_bucket_index = ShardBucketIndex( + bucket_id, global_data_index, global_data_index, bucket_data_index, shard_size + ) + + return item_index_map, bucket_index, shard_bucket_index + + +@dataclasses.dataclass +class Bucket: + """ + A container for holding data in Fully Sharded Data Parallel (FSDP) training. + + Attributes: + data (torch.Tensor): A tensor containing the data elements + grouped together in a bucket. 
+ data_operation_event (Optional[torch.cuda.Event]): An optional CUDA event + used to synchronize data operations. + status (Any): An optional status object used to track the state of the bucket. + + Note: + Buckets are used to optimize communication in FSDP training by + grouping small tensors together. + """ + + data: torch.Tensor + data_operation_event: Optional[torch.cuda.Event] = None + status: Any = None + + +class TemporaryBucketAllocator: + """ + A utility class for managing temporary buckets (buffers) used in FSDP + operations like parameters unshard and gradients reduction. + + This allocator handles the dynamic allocation and deallocation of temporary memory buffers + needed during FSDP (Fully Sharded Data Parallel) operations, particularly for parameters + unshard and gradients reduction. It helps optimize memory usage by allowing temporary + buckets to be released when no longer needed. + + Key Features: + - Dynamic allocation of temporary buckets for FSDP operations + - Memory-efficient management of temporary buffers + - Support for both parameters unshard and gradients reduction operations + - Automatic cleanup of unused buckets to save memory + + Usage: + ```python + # Create an allocator instance + allocator = TemporaryBucketAllocator(name="gpt_parameters") + + # Allocate a temporary bucket + temp_bucket = allocator.allocate(size=1024, dtype=torch.float32) + + # Use the temporary bucket for FSDP operations + # ... perform all-gather or reduce-scatter ... + + # Free the bucket when done + allocator.free(temp_bucket) + ``` + + Note: + It's important to release temporary buckets after use to prevent memory leaks + and optimize memory usage during training. + """ + + def __init__(self): + self.buckets = {} + + def allocate( + self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device + ) -> Bucket: + """ + allocate a temporary bucket. + """ + if bucket_id not in self.buckets: + self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device)) + return self.buckets[bucket_id] + + def free(self, bucket_id: int): + """ + free a temporary bucket. + """ + if bucket_id in self.buckets: + _free_storage(self.buckets[bucket_id].data) + del self.buckets[bucket_id] + + +class StorageResizeBasedBucketAllocator(TemporaryBucketAllocator): + """ + A specialized temporary bucket allocator that resizes the storage of temporary buckets + based on the required size. + """ + + def __init__(self): + self.buckets = {} # {bucket_id: Bucket} + + def allocate( + self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device + ) -> Bucket: + """ + allocate a temporary bucket. + """ + if bucket_id not in self.buckets: + self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device)) + bucket = self.buckets[bucket_id] + _alloc_storage(bucket.data, torch.Size([size])) + return bucket + + def free(self, bucket_id: int): + """ + free a temporary bucket. + """ + if bucket_id in self.buckets: + _free_storage(self.buckets[bucket_id].data) + + +class RotaryBucketAllocator(TemporaryBucketAllocator): + """A specialized temporary bucket allocator that implements a circular buffer recycling strategy + to minimize memory fragmentation in FSDP operations. + + RotaryBucketAllocator extends TemporaryBucketAllocator by maintaining a limited pool of + pre-allocated buffers that are reused in a circular manner. 
This approach helps prevent + memory fragmentation that typically occurs with frequent allocation and deallocation of + temporary buffers during FSDP operations. + + Key Features: + - Circular buffer recycling strategy for memory efficiency + - Reduced memory fragmentation compared to dynamic allocation + - Pre-allocated buffer pool for faster access + - Automatic buffer reuse without explicit deallocation + + Usage: + ```python + # Create a rotary allocator + allocator = RotaryBucketAllocator(name="gpt_parameters") + + # Get a temporary buffer from the pool + temp_bucket = allocator.allocate(size=1024, dtype=torch.float32) + + # Use the temporary bucket for FSDP operations + # ... perform all-gather or reduce-scatter ... + + # Free the bucket when done, make it in idle buffer pool + allocator.free(temp_bucket) + ``` + """ + + def __init__(self, name: str): + self.name = name + self.num_global_buffer = 0 + self.idle_buffer = [] # [buffer_id] + self.using_buffer = {} # {bucket_id: buffer_id} + self.buckets = {} + + def allocate( + self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device + ) -> Bucket: + """ + allocate a temporary bucket. + """ + + def _get_global_buffer(buffer_id: int): + return parallel_state.get_global_memory_buffer().get_tensor( + [size], dtype=dtype, name=self._get_gbuf_name(buffer_id) + ) + + if bucket_id in self.using_buffer: + buffer_id = self.using_buffer[bucket_id] + return Bucket(data=_get_global_buffer(buffer_id)) + + if len(self.idle_buffer) == 0: + # allocate new buffer + buffer_id = self.num_global_buffer + self.num_global_buffer += 1 + self.idle_buffer.append(buffer_id) + + buffer_id = self.idle_buffer.pop(0) + self.using_buffer[bucket_id] = buffer_id + return Bucket(data=_get_global_buffer(buffer_id)) + + def _get_gbuf_name(self, buffer_id: int): + return f"{self.name}_{buffer_id}" + + def free(self, bucket_id: int): + """ + free a temporary bucket. + """ + if bucket_id in self.using_buffer: + buffer_id = self.using_buffer.pop(bucket_id) + self.idle_buffer.append(buffer_id) + + +class DataParallelBuffer: + """ + A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training. 
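+
+    Each instance owns a single bucket of same-dtype parameters, records the index
+    range of every item inside the flat bucket, and, when `is_data_distributed` is
+    True, keeps only this rank's shard while temporary buckets are fetched for
+    communication.
+
+    Usage (a minimal sketch; `ddp_config`, a list `params` of same-dtype parameters,
+    and a data-parallel process group `dp_group` are assumed to already exist):
+    ```python
+    buffer = DataParallelBuffer(
+        ddp_config, params, is_data_distributed=True, bucket_id=0,
+        device=torch.device("cuda"), data_parallel_group=dp_group,
+    )
+    # Materialize the full (unsharded) bucket and point the params at it.
+    bucket = buffer.fetch_bucket(and_allocate_params_data=True)
+    # ... all-gather into / compute with the unsharded bucket ...
+    buffer.free_bucket_storage()
+    ```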
+ """ + + def __init__( + self, + ddp_config: DistributedDataParallelConfig, + params: List[torch.nn.Parameter], + is_data_distributed: bool, + bucket_id: int, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + temporary_bucket_allocator: Optional[TemporaryBucketAllocator] = None, + init_meta_only: bool = False, + is_dtype_float8: bool = False, + gradient_scaling_factor: Optional[float] = None, + ) -> None: + self.ddp_config = ddp_config + self.params = params + _param_dtype = {p.dtype for p in self.params} + assert len(_param_dtype) == 1, f'params have different dtypes: {_param_dtype}' + self.is_data_distributed = is_data_distributed + self.bucket_id = bucket_id + self.dtype = dtype if dtype else next(iter(_param_dtype)) + self.device = device + self.data_parallel_group = data_parallel_group + self.dp_rank = torch.distributed.get_rank(group=self.data_parallel_group) + self.dp_world_size = torch.distributed.get_world_size(group=self.data_parallel_group) + self.temporary_bucket_allocator = ( + temporary_bucket_allocator if temporary_bucket_allocator else TemporaryBucketAllocator() + ) + self.is_dtype_float8 = is_dtype_float8 + self.gradient_scaling_factor = gradient_scaling_factor + + (self.item_index_map, self.bucket_index, self.shard_bucket_index) = ( + build_data_parallel_buffer_index( + [p.shape for p in self.params], + self.dp_rank, + self.dp_world_size, + is_data_distributed, + ddp_config, + bucket_id=bucket_id, + ) + ) + + self.data_size = ( + self.bucket_index.size if not is_data_distributed else self.shard_bucket_index.size + ) + if init_meta_only: + self.data = None + else: + self.data = torch.empty(self.data_size, dtype=self.dtype, device=device) + + self.param_idx = {p: i for i, p in enumerate(self.params)} + self.placeholder_bucket = None + self.placeholder_items = {} + + def fetch_bucket( + self, dtype: Optional[torch.dtype] = None, and_allocate_params_data: bool = False + ) -> Bucket: + """ + Fetch a communication buffer for data-parallel operations. + + The size of the bucket is defined by the `DataParallelBuffer` instance. + If `and_allocate_params_data` is True, this method resets the parameter + data stored in the `DataParallelBuffer` instance. + + Args: + dtype (Optional[torch.dtype], optional): The data type of the tensor + to fetch a buffer for. Defaults to None. + and_allocate_params_data (bool, optional): Whether to allocate and + reset parameter data. Defaults to False. + + Returns: + Bucket: The communication buffer for the specified data type. + """ + if dtype is None: + dtype = self.dtype + bucket_index = self.bucket_index + + if not self.is_data_distributed and dtype == self.dtype: + bucket = Bucket( + data=self.data[ + bucket_index.global_data_index : bucket_index.global_data_index + + bucket_index.size + ] + ) + else: + bucket = self.temporary_bucket_allocator.allocate( + bucket_id=bucket_index.bucket_id, + size=bucket_index.size, + dtype=dtype, + device=self.device, + ) + + if and_allocate_params_data: + for p in self.params: + item_id = self.param_idx[p] + if is_float8tensor(p): + p._data = self.get_item_from_bucket(bucket, item_id).view(p.shape) + else: + p.data = self.get_item_from_bucket(bucket, item_id).view(p.shape) + + return bucket + + def free_bucket_storage(self, and_free_params_data: bool = False): + """ + Release the storage of a temporary communication bucket. + + If the bucket is temporary, this method frees its storage. 
+ If `and_free_params_data` is True, this method also releases the storage + of the parameter data stored in the `DataParallelBuffer` instance. + + Args: + and_free_params_data (bool, optional): Whether to also release the + storage of the parameter data. Defaults to False. + + Returns: + None + """ + if not self.is_data_distributed: + return + + self.temporary_bucket_allocator.free(self.bucket_index.bucket_id) + if and_free_params_data: + if self.placeholder_bucket is None: + self.placeholder_bucket = Bucket( + data=torch.empty(self.bucket_index.size, dtype=self.dtype, device=self.device) + ) + for p in self.params: + item_id = self.param_idx[p] + self.placeholder_items[item_id] = self.get_item_from_bucket( + self.placeholder_bucket, item_id + ).view(p.shape) + _free_storage(self.placeholder_bucket.data) + for p in self.params: + item_id = self.param_idx[p] + if is_float8tensor(p): + p._data = self.placeholder_items[item_id] + else: + p.data = self.placeholder_items[item_id] + + def _get_item_slice_in_shard(self, item_id: int) -> Tuple[int, int]: + item_index = self.item_index_map[item_id] + shard_bucket_index = self.shard_bucket_index + + item_global_start = item_index.global_data_index + item_global_end = item_index.global_data_index + item_index.size + shard_bucket_start = shard_bucket_index.global_data_index + shard_bucket_end = shard_bucket_index.global_data_index + shard_bucket_index.size + + if item_global_start > shard_bucket_end or item_global_end < shard_bucket_start: + return (0, 0) + + start = max(item_global_start, shard_bucket_start) - item_global_start + end = min(item_global_end, shard_bucket_end) - item_global_start + + return (start, end) + + # pylint: disable=missing-function-docstring + def locate_item_in_global_item(self, item_id: int) -> Tuple[int, int]: + item_index = self.item_index_map[item_id] + if not self.is_data_distributed: + return (0, item_index.size) + + slice_start, slice_end = self._get_item_local_shard_index(item_id) + if slice_start == slice_end: + return (0, 0) + + local_shard_index_to_global_index_offset = ( + self.shard_bucket_index.global_data_index - self.shard_bucket_index.local_data_index + ) + slice_start += local_shard_index_to_global_index_offset + slice_end += local_shard_index_to_global_index_offset + return ( + slice_start - item_index.global_data_index, + slice_end - item_index.global_data_index, + ) + + def _get_item_local_shard_index(self, item_id: int) -> Tuple[int, int]: + slice_start, slice_end = self._get_item_slice_in_shard(item_id) + if slice_start == slice_end: + return (0, 0) + + item_index = self.item_index_map[item_id] + shard_bucket_index = self.shard_bucket_index + offset = ( + item_index.global_data_index + - shard_bucket_index.global_data_index + + shard_bucket_index.local_data_index + ) + + return (offset + slice_start, offset + slice_end) + + def _get_item_local_index(self, item_id: int) -> Tuple[int, int]: + if not self.is_data_distributed: + item_index = self.item_index_map[item_id] + return (item_index.global_data_index, item_index.global_data_index + item_index.size) + + return self._get_item_local_shard_index(item_id) + + def set_item(self, item_id: int, item_data: torch.Tensor) -> None: + """ + Update a tensor item managed by the `DataParallelBuffer` instance. + + The storage of the item is mapped to the communication bucket. + This method updates the item data and ensures consistency with the bucket. + + Args: + item_id (int): The ID of the tensor item to update. 
+ item_data (torch.Tensor): The new data for the tensor item. + + Returns: + None + """ + if self.is_data_distributed: + slice_start, slice_end = self._get_item_slice_in_shard(item_id) + item_data = item_data.flatten()[slice_start:slice_end] + local_index_start, local_index_end = self._get_item_local_index(item_id) + shard = self.data[local_index_start:local_index_end] + if shard.numel() > 0: + shard.data.copy_(item_data.flatten()) + + def get_item(self, item_id: int, only_shard: bool = False) -> torch.Tensor: + """ + Retrieve a tensor item managed by the `DataParallelBuffer` instance. + + The storage of the item is mapped to the communication bucket. + If `only_shard` is True, returns only the shard of the item corresponding + to the current process. + Otherwise, returns the entire item. + + Args: + item_id (int): The ID of the tensor item to retrieve. + only_shard (bool, optional): Whether to return only the shard of the + item. Defaults to False. + + Returns: + torch.Tensor: The retrieved tensor item. + """ + if only_shard: + start, end = self._get_item_local_shard_index(item_id) + else: + start, end = self._get_item_local_index(item_id) + + return self.data[start:end] + + def get_item_from_bucket(self, bucket: Bucket, item_id: int): + """get item from bucket.""" + item_index = self.item_index_map[item_id] + bucket_index = self.bucket_index + start_index = item_index.global_data_index - bucket_index.global_data_index + end_index = start_index + item_index.size + item = bucket.data[start_index:end_index] + return item + + def get_shard_from_bucket(self, bucket: Bucket): + """Get the local sharding of the bucket.""" + shard_bucket_index = self.shard_bucket_index + offset = shard_bucket_index.bucket_data_index + shard_size = shard_bucket_index.size + shard = bucket.data[offset : offset + shard_size] + return shard + + def get_shard_from_local_buffer(self) -> torch.Tensor: + """Get the local sharding of the bucket.""" + index = self.shard_bucket_index + return self.data[index.local_data_index : index.local_data_index + index.size] + + +@dataclasses.dataclass +class ParameterGroup: + """ + A group of model parameters with associated metadata for data-parallel training. + + This dataclass encapsulates a list of PyTorch parameters and additional information + necessary for managing data-parallel operations, such as data type, gradient requirements, + and buffer assignments. + """ + + params: List[torch.nn.Parameter] + dtype: Optional[torch.dtype] = None + is_expert_param: bool = False + requires_grad: Optional[bool] = None + fsdp_unit_id: Optional[int] = None + data_parallel_world_size: Optional[int] = None + model_weight_buffer: Optional[DataParallelBuffer] = None + main_weight_buffer: Optional[DataParallelBuffer] = None + main_grad_buffer: Optional[DataParallelBuffer] = None + + +def _get_parameter_groups( + module: torch.nn.Module, policy: BucketingPolicy, meta_device_init_fp8_params: dict +): + """ + Get the parameter group for the given module and parameters. + """ + param_to_name = {p: name for name, p in module.named_parameters()} + fsdp_units = [] + if policy.fsdp_unit_modules: + param_to_id = {} + for i, p in enumerate(module.parameters()): + param_to_id[p] = i + fsdp_modules = [] + for m in module.modules(): + # Skip nested FSDP module. 
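+            # A module already covered by a previously recorded FSDP unit is skipped,
+            # so every parameter ends up owned by exactly one FSDP unit.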
+ if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules): + continue + if isinstance(m, tuple(policy.fsdp_unit_modules)): + fsdp_units.append([param_to_name[p] for p in m.parameters()]) + fsdp_modules.append(m) + + def _does_param_require_new_bucket(param): + """ + Split shared embedding parameters into separate bucket if using distributed + optimizer that makes use of reduce-scatters instead of all-reduces. + This ensures that the first and last pipeline stage partition optimizer state + for the shared embedding parameters the same way across DP replicas, allowing + the DP reduce-scatter to be before the embedding all-reduce. + """ + return ( + getattr(param, "shared_embedding", False) + and policy.data_parallel_sharding_strategy != "no_shard" + ) + + is_expert_parameter = lambda p: not getattr(p, 'allreduce', True) + + # Step 1: Group the parameters according to their execution order and attributes. + parameter_groups = [] + for name, param in module.named_parameters(): + param_attrs = dict( + dtype=( + "float8" + if is_float8tensor(param) or meta_device_init_fp8_params.get(name, False) + else param.dtype + ), + is_expert_param=is_expert_parameter(param), + requires_grad=param.requires_grad, + fsdp_unit_id=None, + ) + for fsdp_unit_id, fsdp_unit in enumerate(fsdp_units): + if name in fsdp_unit: + param_attrs["fsdp_unit_id"] = fsdp_unit_id + break + + found_group = False + for param_group in parameter_groups: + group_attrs = { + key: value for key, value in param_group.__dict__.items() if key in param_attrs + } + if group_attrs == param_attrs: + param_group.params.append(param) + found_group = True + break + + if not found_group: + parameter_groups.append(ParameterGroup([param], **param_attrs)) + + # Step 2: Bucket the parameters based on the guide bucket size. + suggested_bucket_size = policy.suggested_bucket_size + bucket_groups = [] + for group in parameter_groups: + bucket = [] + + basic_attrs = { + key: value + for key, value in group.__dict__.items() + if key in ['dtype', 'is_expert_param', 'requires_grad', 'fsdp_unit_id'] + } + for param in group.params: + if _does_param_require_new_bucket(param): + if len(bucket) > 0: + bucket_groups.append(ParameterGroup(bucket, **basic_attrs)) + bucket_groups.append(ParameterGroup([param], **basic_attrs)) + bucket = [] + continue + + bucket.append(param) + if ( + group.fsdp_unit_id is None + and suggested_bucket_size + and sum([p.numel() for p in bucket]) >= suggested_bucket_size + ): + bucket_groups.append(ParameterGroup(bucket, **basic_attrs)) + bucket = [] + continue + + if bucket: + bucket_groups.append(ParameterGroup(bucket, **basic_attrs)) + + param_to_param_group = {} + for group_id, group in enumerate(bucket_groups): + for param in group.params: + param_to_param_group[param] = group_id + + # Log buckets for all PP stages. 
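+    # Only the first data-parallel / tensor-parallel rank builds the log message;
+    # log_on_each_pipeline_stage then emits it once per pipeline stage.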
+ if ( + parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0 + and parallel_state.get_tensor_model_parallel_rank() == 0 + ): + log_strs = [] + log_strs.append(f'Number of parameter groups for FSDP: {len(bucket_groups)}') + for index, group in enumerate(bucket_groups): + numel = 0 + for param in group.params: + numel += param.numel() + log_strs.append( + f"Params for group {index+1} ({numel} elements, dtype {group.dtype}, " + f"has_weight_buffer: {group.model_weight_buffer is not None}, " + f"has_grad_buffer: {group.main_grad_buffer is not None}, " + f"has_main_weight_buffer: {group.main_weight_buffer is not None}):" + ) + for param in group.params: + log_strs.append(f'\t{param_to_name[param]}') + log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs)) + + return (bucket_groups, fsdp_units, param_to_param_group) + + +class ParamAndGradBuffer: + """A class that manages parameter grouping, buffer allocation, and + communication operations for data-parallel distributed training. + + This class provides functionality to: + 1. Group parameters based on their data types and communication group sizes + 2. Create contiguous buffers for model weights, gradients, and high-precision + main weights + 3. Handle parameter unsharding, gradient reduction, and weight + synchronization operations + + Key Features: + - Efficient parameter grouping based on data types and communication patterns + - Memory-efficient contiguous buffer allocation + - Support for mixed-precision training with main weights + - Distributed operations including parameters all-gather and gradients + reduce-scatter/all-reduce + - Synchronized weight updates between model and main weights + + Note: + This class is designed for distributed training scenarios where efficient + parameter management and communication are crucial for performance. + + Args: + ddp_config (DistributedDataParallelConfig): The distributed data parallel + configuration. + module (torch.nn.Module): The module whose parameters are to be grouped + and flatten. + bucketing_policy (BucketingPolicy): The bucketing policy. + data_parallel_group (torch.distributed.ProcessGroup): The data parallel group. + expert_data_parallel_group (Optional[torch.distributed.ProcessGroup]): + The expert data parallel group. + preserve_fp32_weights (bool): Whether to preserve FP32 weights. + grad_reduce_in_fp32 (bool): Whether to reduce gradients in FP32. + gradient_scaling_factor (Optional[float]): The gradient scaling factor. + expert_gradient_scaling_factor (Optional[float]): The expert gradient + scaling factor. + device (torch.device): The parameter and gradient buffer device. + only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad (bool): + Whether to only create the gradient buffer and main weight buffer + for parameters that require gradients. Default is True. 
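+        reset_parameters_for_meta_device_init_module (bool): Whether to re-initialize
+            parameters of modules that were constructed on the meta device.
+            Default is False.
+
+    Example:
+        A minimal construction sketch; `ddp_config`, `module`, `policy`, and `dp_group`
+        are assumed to already exist and are not defined here.
+
+        >>> buffer = ParamAndGradBuffer(
+        ...     ddp_config,
+        ...     module,
+        ...     bucketing_policy=policy,
+        ...     data_parallel_group=dp_group,
+        ... )
+        >>> buffer.zero_grad()                            # clear grads for a new iteration
+        >>> # ... forward / backward ...
+        >>> buffer.copy_main_weights_to_model_weights()   # after the optimizer step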
+ """ + + def __init__( + self, + ddp_config: DistributedDataParallelConfig, + module: torch.nn.Module, + bucketing_policy: BucketingPolicy, + data_parallel_group: torch.distributed.ProcessGroup, + expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + preserve_fp32_weights: bool = True, + grad_reduce_in_fp32: bool = True, + gradient_scaling_factor: Optional[float] = None, + expert_gradient_scaling_factor: Optional[float] = None, + device: torch.device = torch.device('cuda'), + only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad: bool = True, + reset_parameters_for_meta_device_init_module: bool = False, + ): + self.ddp_config = ddp_config + self.module = module + self.bucketing_policy = bucketing_policy + self.param_to_name = {p: name for name, p in self.module.named_parameters()} + self.preserve_fp32_weights = preserve_fp32_weights + self.grad_reduce_in_fp32 = grad_reduce_in_fp32 + self.data_parallel_group = data_parallel_group + self.expert_data_parallel_group = expert_data_parallel_group + self.params = list(module.parameters()) + self.gradient_scaling_factor = gradient_scaling_factor + self.expert_gradient_scaling_factor = expert_gradient_scaling_factor + self.device = device + self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad = ( + only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad + ) + self.reset_parameters_for_meta_device_init_module = ( + reset_parameters_for_meta_device_init_module + ) + + # Mark fp8 param. + meta_device_init_fp8_params = {} + if reset_parameters_for_meta_device_init_module: + for m in module.modules(): + if not isinstance(m, TransformerEngineBaseModule): + continue + for name, param in m.named_parameters(recurse=False): + # The fp8 param initialized from the meta device may NOT be + # an fp8 tensor, according to the internal logic of the TE + # to determine whether this parameter is fp8 or not. + fp8_meta_index = m.param_init_meta[name].fp8_meta_index + if m.primary_weights_in_fp8 and fp8_meta_index is not None: + meta_device_init_fp8_params[self.param_to_name[param]] = True + + # Get the parameter groups. + (self.parameter_groups, self.fsdp_units, self.param_to_param_group) = _get_parameter_groups( + module, bucketing_policy, meta_device_init_fp8_params + ) + + self._init_each_parameter_group_buffers(meta_device_init_fp8_params) + + # Initialize the optimizer named parameters. + self.optimizer_named_parameters = self._init_optimizer_named_parameters() + + def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): + """ + Initialize the buffers for each parameter group. 
+ """ + data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy + if data_parallel_sharding_strategy == 'no_shard': + is_model_weight_buffer_distributed = False + is_main_weight_buffer_distributed = False + is_grad_buffer_distributed = False + elif data_parallel_sharding_strategy == 'optim': + is_model_weight_buffer_distributed = False + is_main_weight_buffer_distributed = True + is_grad_buffer_distributed = False + elif data_parallel_sharding_strategy == 'optim_grads': + is_model_weight_buffer_distributed = False + is_main_weight_buffer_distributed = True + is_grad_buffer_distributed = True + elif data_parallel_sharding_strategy == 'optim_grads_params': + is_model_weight_buffer_distributed = True + is_main_weight_buffer_distributed = True + is_grad_buffer_distributed = True + else: + raise ValueError( + f'Invalid data_parallel_sharding_strategy: {data_parallel_sharding_strategy}' + ) + + self.memory_allocator_for_model_weight_buffer = StorageResizeBasedBucketAllocator() + self.buffer_all_in_one = True + + preserve_fp32_weights = self.preserve_fp32_weights + grad_reduce_in_fp32 = self.grad_reduce_in_fp32 + buffer_size = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0} + for group_id, group in enumerate(self.parameter_groups): + dp_group = ( + self.data_parallel_group + if not group.is_expert_param + else self.expert_data_parallel_group + ) + group.data_parallel_world_size = torch.distributed.get_world_size(group=dp_group) + gradient_scaling_factor = ( + self.gradient_scaling_factor + if not group.is_expert_param + else self.expert_gradient_scaling_factor + ) + one_param = group.params[0] + is_dtype_float8 = is_float8tensor(one_param) or meta_device_init_fp8_params.get( + self.param_to_name[one_param], False + ) + if is_dtype_float8: + param_dtype = torch.uint8 + grad_dtype = torch.bfloat16 + else: + param_dtype = group.params[0].dtype + grad_dtype = param_dtype + should_create_grad_buffer_or_main_weight_buffer = ( + not self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad + or group.requires_grad + ) + + # Initialize the model weight buffer. + if data_parallel_sharding_strategy != 'no_shard': + group.model_weight_buffer = DataParallelBuffer( + self.ddp_config, + group.params, + is_data_distributed=is_model_weight_buffer_distributed + and group.data_parallel_world_size > 1, + dtype=param_dtype, + device=self.device, + data_parallel_group=dp_group, + init_meta_only=True, + is_dtype_float8=is_dtype_float8, + temporary_bucket_allocator=self.memory_allocator_for_model_weight_buffer, + bucket_id=group_id, + ) + + # Initialize the main weight buffer. + if should_create_grad_buffer_or_main_weight_buffer and preserve_fp32_weights: + group.main_weight_buffer = DataParallelBuffer( + self.ddp_config, + group.params, + is_data_distributed=is_main_weight_buffer_distributed + and group.data_parallel_world_size > 1, + dtype=torch.float32, + device=self.device, + data_parallel_group=dp_group, + init_meta_only=True, + bucket_id=group_id, + ) + + # Initialize the main grad buffer. 
+ if should_create_grad_buffer_or_main_weight_buffer: + group.main_grad_buffer = DataParallelBuffer( + self.ddp_config, + group.params, + is_data_distributed=is_grad_buffer_distributed + and group.data_parallel_world_size > 1, + dtype=torch.float32 if grad_reduce_in_fp32 else grad_dtype, + device=self.device, + data_parallel_group=dp_group, + init_meta_only=True, + is_dtype_float8=not grad_reduce_in_fp32 and grad_dtype is torch.uint8, + gradient_scaling_factor=gradient_scaling_factor, + bucket_id=group_id, + ) + if grad_reduce_in_fp32: + buffer_size[torch.float32] += group.main_grad_buffer.data_size + elif group.main_grad_buffer.is_dtype_float8: + buffer_size["float8"] += group.main_grad_buffer.data_size + else: + buffer_size[group.main_grad_buffer.dtype] += group.main_grad_buffer.data_size + + reset_context_args = {"init_param_with_fp8": self.ddp_config.fp8_param_gather} + module_reset_flag = {} + if self.reset_parameters_for_meta_device_init_module: + self.param_to_direct_module = {} + for name, m in self.module.named_modules(): + for p in m.parameters(recurse=False): + self.param_to_direct_module[p] = (name, m) + + meta_params_numel = 0 + cuda_params_numel = 0 + cpu_params_numel = 0 + for group in self.parameter_groups: + for p in group.params: + if p.is_meta: + meta_params_numel += p.numel() + elif p.device.type == 'cuda': + cuda_params_numel += p.numel() + else: + cpu_params_numel += p.numel() + log_str = ( + f"Meta params numel: {meta_params_numel / 1_000_000:.2f} M, " + f"CUDA params numel: {cuda_params_numel / 1_000_000:.2f} M, " + f"CPU params numel: {cpu_params_numel / 1_000_000:.2f} M" + ) + log_on_each_pipeline_stage(logger, logging.INFO, log_str) + + # Initialize the model weight buffer data of each parameter group. + for group in self.parameter_groups: + wbuf = group.model_weight_buffer + if wbuf: + wbuf.data = torch.empty(wbuf.data_size, dtype=wbuf.dtype, device=self.device) + bucket = wbuf.fetch_bucket() + mbuf = group.main_weight_buffer + if mbuf: + mbuf.data = torch.empty(mbuf.data_size, dtype=mbuf.dtype, device=self.device) + for item_id, p in enumerate(group.params): + if wbuf: + if self.reset_parameters_for_meta_device_init_module and p.is_meta: + m_name, m = self.param_to_direct_module[p] + if not module_reset_flag.get(m_name, False) and hasattr( + m, "reset_parameters" + ): + old_params = list(m.parameters(recurse=False)) + + # If the GPU memory over threshold, empty cache to leave + # some memory for initialization of the model on the + # CUDA device. 
+ if check_gpu_memory(threshold=0.5): + gc.collect() + torch.cuda.empty_cache() + + m.to_empty(device=self.device, recurse=False) + if is_te_min_version("0.9.0") and not isinstance( + m, TransformerEngineBaseModule + ): + reset_context_args["with_cuda_rng_tracker"] = True + with ResetParametersContext(**reset_context_args): + m.reset_parameters() + module_reset_flag[m_name] = True + new_params = list(m.parameters(recurse=False)) + + self._reset_parameters(old_params, new_params) + p = group.params[item_id] + assert not p.is_meta, (self.param_to_name[p], module_reset_flag) + wbuf.set_item(item_id, p.data) + + # reset the parameter data to the buffer + old_param_data = p.data + new_param_data = wbuf.get_item_from_bucket(bucket, item_id).view(p.shape) + if is_float8tensor(p): + p._data = new_param_data + else: + p.data = new_param_data + assert old_param_data._base is None + p.data.detach().copy_(old_param_data) + del old_param_data + if mbuf: + if hasattr(p, 'get_high_precision_init_val'): + mbuf.set_item(item_id, p.get_high_precision_init_val()) + p.clear_high_precision_init_val() + else: + mbuf.set_item(item_id, p) + + if wbuf and wbuf.is_data_distributed: + """ + When MCore Custom FSDP `optim_grads_params` is enabled, + it is necessary to save the tensor local shard. This local shard is + accessible through the `fully_shard_param_local_shard` + attribute of the tensor. + + This attribute contains the local shard of the fully + sharded parameter, which is essential for correctly + saving and loading the model state when using + `optim_grads_params` with FSDP. + + Example: + >>> # Assuming `tensor` is a fully sharded parameter + >>> local_shard = tensor.fully_shard_param_local_shard + >>> # Save the local shard as needed + """ + local_shard = wbuf.get_item(item_id, only_shard=True) + local_shard.fsdp_shard_orig_param = p + p.fully_shard_param_local_shard = local_shard + p.fully_shard_param_local_index = wbuf.locate_item_in_global_item(item_id) + + def disable_shard_param_to_function(*unused): + """Prevents users from accessing the 'to' operation + on parameters after sharding. + + This restriction helps maintain data integrity and + proper sharding behavior by disabling direct 'to' + device/dtype operations on sharded parameters. + """ + raise RuntimeError( + "Your model is wrapped by MCore Custom FSDP. All " + "parameter dtypes and devices must be set before FSDP " + "wrapping. After FSDP wrapping, parameter storage " + "is sharded and you cannot modify parameter " + "dtypes or devices." + ) + + setattr(p, 'to', disable_shard_param_to_function) + + def disable_shard_param_cpu_function(*unused): + warnings.warn( + "The parameters are sharded by custom fsdp, " + "and no actual cpu operation is performed." + ) + return torch.empty([], device='cpu') + + setattr(p, 'cpu', disable_shard_param_cpu_function) + + if wbuf and wbuf.is_data_distributed: + wbuf.free_bucket_storage() + + # Allocate the main_weight buffer and main_grad buffer data in one buffer. 
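+        # One flat tensor per dtype is allocated up front and handed out in slices
+        # by `_alloc` below, so buffers of the same dtype share a single storage.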
+ if self.buffer_all_in_one: + self.buffer = { + torch.float32: torch.empty( + buffer_size[torch.float32], dtype=torch.float32, device=self.device + ), + torch.float16: torch.empty( + buffer_size[torch.float16], dtype=torch.float16, device=self.device + ), + torch.bfloat16: torch.empty( + buffer_size[torch.bfloat16], dtype=torch.bfloat16, device=self.device + ), + "float8": torch.empty(buffer_size["float8"], dtype=torch.uint8, device=self.device), + } + offset = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0} + + def _alloc(dtype, size): + if self.buffer_all_in_one: + if dtype == torch.uint8: + dtype = "float8" + data = self.buffer[dtype][offset[dtype] : offset[dtype] + size] + offset[dtype] += size + return data + return torch.empty(size, dtype=dtype, device=self.device) + + # Initialize the main grad buffer data of each parameter group. + for group in self.parameter_groups: + gbuf = group.main_grad_buffer + if not gbuf: + continue + gbuf.data = _alloc(gbuf.dtype, gbuf.data_size) + gbuf.data.zero_() + for item_id, p in enumerate(group.params): + p.fsdp_managed_main_grad = gbuf.get_item(item_id) + p._gbuf = gbuf + p._item_id = item_id + + def main_grad_getter(p): + # Make sure main_grad memory storage ready. + bucket = p._gbuf.fetch_bucket() + gbuf = p._gbuf + item_id = p._item_id + if bucket.status == GradBucketStatus.GRAD_REDUCING: + if bucket.data_operation_event: + bucket.data_operation_event.wait() + bucket.data_operation_event = None + # Here it is assumed that main_grad is taken out and do + # gradient accumulation and should not be freed up before + # gradient reduction. + bucket.status = GradBucketStatus.GRAD_ACCUMULATING + return gbuf.get_item_from_bucket(bucket, item_id).view(p.shape) + + setattr(p.__class__, 'main_grad', property(main_grad_getter)) + + if gbuf.is_data_distributed: + gbuf.free_bucket_storage() + + gc.collect() + torch.cuda.empty_cache() + + def _reset_parameters(self, old_params, new_params): + assert len(old_params) == len(new_params) + param_map = {} + for old_param, new_param in zip(old_params, new_params): + param_map[old_param] = new_param + self.param_to_name[new_param] = self.param_to_name[old_param] + del self.param_to_name[old_param] + + self.param_to_param_group[new_param] = self.param_to_param_group[old_param] + del self.param_to_param_group[old_param] + + self.param_to_direct_module[new_param] = self.param_to_direct_module[old_param] + del self.param_to_direct_module[old_param] + + for item_id, p in enumerate(self.params): + if p in param_map: + new_p = param_map[p] + self.params[item_id] = new_p + + for group in self.parameter_groups: + for item_id, p in enumerate(group.params): + if p not in param_map: + continue + new_p = param_map[p] + group.params[item_id] = new_p + for buf in [ + group.model_weight_buffer, + group.main_weight_buffer, + group.main_grad_buffer, + ]: + if buf is None: + continue + buf.param_idx[new_p] = buf.param_idx[p] + del buf.param_idx[p] + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale the gradient data by `scaling_factor`.""" + for group in self.parameter_groups: + if group.main_grad_buffer is None: + continue + group.main_grad_buffer.data *= scaling_factor + + def zero_grad(self): + """ + Zero out the underlying grad_buffer and reset all buckets in preparation + for the next iteration of training. 
+ """ + for _, param in self.optimizer_named_parameters: + if param.grad is not None and param.grad._base is None: + # For tensors that are not referenced, trying to use storage + # resize to make memory free immediately. + _free_storage(param.grad) + param.grad = None + + for group in self.parameter_groups: + if group.main_grad_buffer is None: + continue + group.main_grad_buffer.data.zero_() + + def _init_optimizer_named_parameters(self) -> List[Tuple[str, torch.nn.Parameter]]: + named_parameters = [] + for pg in self.parameter_groups: + if pg.main_grad_buffer is None: + continue + + optimizer_state_is_shard = pg.main_grad_buffer.is_data_distributed or ( + pg.main_weight_buffer and pg.main_weight_buffer.is_data_distributed + ) + for item_id, orig_param in enumerate(pg.params): + if pg.main_weight_buffer: + param = pg.main_weight_buffer.get_item( + item_id, only_shard=optimizer_state_is_shard + ) + elif pg.model_weight_buffer: + param = pg.model_weight_buffer.get_item( + item_id, only_shard=optimizer_state_is_shard + ) + else: + param = orig_param + + def set_param_attribute_closure(param, orig_param): + def set_param_attribute(): + for attr_name in [ + 'requires_grad', + 'sequence_parallel', + 'shared', + 'tensor_model_parallel', + 'partition_dim', + 'partition_stride', + 'is_embedding_or_output_parameter', + ]: + if hasattr(orig_param, attr_name): + setattr(param, attr_name, getattr(orig_param, attr_name)) + + return set_param_attribute + + setattr(param, 'reset_attribute', set_param_attribute_closure(param, orig_param)) + setattr(param, 'orig_param', orig_param) + param.reset_attribute() + named_parameters.append((self.param_to_name[orig_param], param)) + + return named_parameters + + def update_main_grads(self): + """Update the main gradients for preparing the optimizer step.""" + for _, param in self.optimizer_named_parameters: + param.reset_attribute() + orig_param = param.orig_param + group = self.parameter_groups[self.param_to_param_group[orig_param]] + item_id = group.main_grad_buffer.param_idx[orig_param] + optimizer_grad = group.main_grad_buffer.get_item( + item_id, only_shard=group.main_weight_buffer.is_data_distributed + ) + setattr( + param, + 'grad', + optimizer_grad.to(param.dtype) if optimizer_grad.numel() > 0 else None, + ) + + @property + def num_buckets(self): + """Return the number of buckets.""" + return len(self.parameter_groups) + + @torch.no_grad() + def copy_main_weights_to_model_weights(self): + """Update the model weights from the main weights.""" + for pg in self.parameter_groups: + mbuf = pg.main_weight_buffer + wbuf = pg.model_weight_buffer + if mbuf is None: + continue + + for param in pg.params: + item_id = mbuf.param_idx[param] + if wbuf: + if wbuf.is_data_distributed or mbuf.is_data_distributed: + model_param = wbuf.get_item(item_id, only_shard=True) + main_weight = mbuf.get_item(item_id, only_shard=True) + else: + model_param = wbuf.get_item(item_id) + main_weight = mbuf.get_item(item_id) + else: + assert not mbuf.is_data_distributed + model_param = param + main_weight = pg.main_weight_buffer.get_item(item_id) + + if model_param.numel() == 0: + continue + + if is_float8tensor(param): + # 1. When "--fp8-param-gather" is disabled, the main param + # is first casted to BF16/FP16, and then casted to FP8, so + # the amax_history is calculated using BF16/FP16 param. + # 2. When "--fp8-param-gather" is enabled, we can cast the + # FP32 main param to FP8 directly, which results in slightly + # different results with higher performance. 
In theory, this
+                    # does not affect convergence.
+                    # TODO: The following code maintains the logic of point 1
+                    # above. It can be deleted if it is not necessary.
+                    main_weight = main_weight.to(param.dtype)
+                    cast_to_fp8(
+                        main_weight.view(1, -1),
+                        param._fp8_meta['scaling_fwd'],
+                        param._fp8_meta_index,
+                        param._fp8_dtype,
+                        out=model_param.view(1, -1),
+                    )
+                else:
+                    model_param.data.copy_(main_weight.view(model_param.shape))
+
+    @torch.no_grad()
+    def copy_model_weights_to_main_weights(self):
+        """Copy the model weights to the main weights."""
+        for group in self.parameter_groups:
+            mbuf = group.main_weight_buffer
+            if mbuf is None:
+                continue
+            wbuf = group.model_weight_buffer
+            if mbuf.is_data_distributed:
+                copyin_data = wbuf.get_shard_from_local_buffer()
+            else:
+                copyin_data = wbuf.data
+            assert mbuf.data.numel() == copyin_data.numel(), (
+                f"Main weight buffer size {mbuf.data.numel()} does not match "
+                f"model weight buffer size {copyin_data.numel()}"
+            )
+            mbuf.data.copy_(copyin_data.data)
+
+    def all_gather_parameters(self, async_op: bool = True):
+        """All-gather the parameters.
+        Args:
+            async_op (bool, optional): Whether to do the all-gather
+                asynchronously. Defaults to True.
+        """
+        assert all(
+            [not g.model_weight_buffer.is_data_distributed for g in self.parameter_groups]
+        ), 'all_gather_parameters() should only be called when parameters are not sharded.'
+
+        all_gather_ops = []
+        for g in self.parameter_groups:
+            shard = g.model_weight_buffer.get_shard_from_local_buffer()
+            all_gather_handler = torch.distributed.all_gather_into_tensor(
+                output_tensor=g.model_weight_buffer.data,
+                input_tensor=shard,
+                group=g.model_weight_buffer.data_parallel_group,
+                async_op=async_op,
+            )
+            if async_op:
+                all_gather_ops.append(all_gather_handler)
+
+        for op in all_gather_ops:
+            op.wait()
+
+    def reduce_scatter_gradients(self, async_op: bool = True):
+        """Reduce-scatter the gradients.
+        Args:
+            async_op (bool, optional): Whether to do the reduce-scatter
+                asynchronously. Defaults to True.
+        """
+        assert all(
+            [
+                not g.main_grad_buffer.is_data_distributed
+                for g in self.parameter_groups
+                if g.main_grad_buffer
+            ]
+        ), 'reduce_scatter_gradients() should only be called when gradients are not sharded.'
+
+        reduce_scatter_ops = []
+        for g in self.parameter_groups:
+            gbuf = g.main_grad_buffer
+            if gbuf is None:
+                continue
+            scaling_factor = gbuf.gradient_scaling_factor
+            reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
+            reduce_scatter_handler = torch.distributed.reduce_scatter_tensor(
+                output=gbuf.get_shard_from_local_buffer(),
+                input=gbuf.data,
+                op=reduce_op,
+                group=gbuf.data_parallel_group,
+                async_op=async_op,
+            )
+
+            if async_op:
+                reduce_scatter_ops.append(reduce_scatter_handler)
+
+        for op in reduce_scatter_ops:
+            op.wait()
+
+    def all_reduce_gradients(self, async_op: bool = False):
+        """All-reduce the gradients.
+        Args:
+            async_op (bool, optional): Whether to do the all-reduce
+                asynchronously. Defaults to False.
+        """
+        assert all(
+            [
+                not g.main_grad_buffer.is_data_distributed
+                for g in self.parameter_groups
+                if g.main_grad_buffer
+            ]
+        ), 'all_reduce_gradients() should only be called when gradients are not sharded.'
+
+        all_reduce_ops = []
+        for g in self.parameter_groups:
+            gbuf = g.main_grad_buffer
+            if gbuf is None:
+                continue
+            scaling_factor = gbuf.gradient_scaling_factor
+            reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
+            all_reduce_handler = torch.distributed.all_reduce(
+                gbuf.data, op=reduce_op, group=gbuf.data_parallel_group, async_op=async_op
+            )
+            if async_op:
+                all_reduce_ops.append(all_reduce_handler)
+
+        for op in all_reduce_ops:
+            op.wait()
+
+
+class BucketStatus(Enum):
+    """
+    An enumeration of possible statuses for a data-parallel communication bucket.
+
+    Attributes:
+        EMPTY (int): The bucket is empty and not in use.
+        COMMUNICATING (int): The bucket is currently being used for communication.
+        READY_TO_USE (int): The bucket is filled with data and ready for use.
+    """
+
+    EMPTY = 1
+    COMMUNICATING = 2
+    READY_TO_USE = 3
+
+
+class GradBucketStatus(Enum):
+    """
+    An enumeration of possible statuses for a gradient bucket.
+
+    Attributes:
+        GRAD_ACCUMULATING (int): The gradient bucket is currently accumulating gradients.
+        GRAD_REDUCING (int): The gradient bucket is currently reducing gradients.
+    """
+
+    GRAD_ACCUMULATING = 1
+    GRAD_REDUCING = 2
+
+
+class GradReducePipeline:
+    """
+    Pipeline for reducing gradients.
+    """
+
+    def __init__(
+        self,
+        param_and_grad_buffer: ParamAndGradBuffer,
+        cuda_stream: Optional[torch.cuda.Stream] = None,
+        check_nans: bool = False,
+    ) -> None:
+        self.buffer = param_and_grad_buffer
+        self.grad_reduce_queue = []
+        self.bucket_status = {
+            i: BucketStatus.EMPTY
+            for i in range(self.buffer.num_buckets)
+            if self.buffer.parameter_groups[i].main_grad_buffer
+        }
+        self.buckets = {}
+        self.cuda_stream = cuda_stream
+        self.check_nans = check_nans
+
+    @property
+    def num_buckets(self):
+        """Return the number of buckets."""
+        return self.buffer.num_buckets
+
+    def reset(self):
+        """Reset the pipeline state."""
+        assert len(self.grad_reduce_queue) == 0, (
+            f"There are still pending reduce-scatter tasks, it is not safe to reset. "
+            f"buckets in queue: {[bucket_id for _, _, bucket_id in self.grad_reduce_queue]}, "
+            f"bucket_status: {self.bucket_status}."
+        )
+        for bucket_id, _ in self.bucket_status.items():
+            gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
+            gbuf.free_bucket_storage()
+            self.bucket_status[bucket_id] = BucketStatus.EMPTY
+        assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), (
+            f"There are still pending buckets, it is not safe to reset. "
+            f"bucket_status: {self.bucket_status}."
+        )
+
+        self.buckets = {}
+
+    def place_bucket(self, bucket_id: int) -> bool:
+        """Allocate a full-size bucket for the given bucket id.
+        Args:
+            bucket_id (int): The bucket id.
+        Returns:
+            bool: True if the bucket is placed successfully.
+        """
+        assert bucket_id in self.bucket_status, f"Bucket {bucket_id} is not in the bucket status."
+        bucket_status = self.bucket_status[bucket_id]
+        if bucket_status == BucketStatus.READY_TO_USE:
+            return False
+        if bucket_status == BucketStatus.COMMUNICATING:
+            self.wait_for_previous_grad_reduce(0)
+
+        assert bucket_id not in self.buckets, f"Bucket {bucket_id} is already allocated."
+ + gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer + bucket = gbuf.fetch_bucket() + requires_grad_items = sum([p.requires_grad for p in gbuf.params]) + setattr(bucket, 'requires_grad_items', requires_grad_items) + setattr(bucket, 'items', []) + + self.buckets[bucket_id] = bucket + self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE + return True + + def wait_for_previous_grad_reduce( + self, recommeded_queue_size: int = 1, recommeded_queue_capacity: Optional[int] = None + ): + """ + Wait for the previous reduce-scatter/all-reduce to finish. + Args: + recommeded_queue_size (int, optional): The recommended queue size. Defaults to 1. + recommeded_queue_capacity (Optional[int], optional): The recommended queue capacity. + Defaults to None. + """ + if recommeded_queue_capacity is not None: + queue_space = sum( + [ + self.buffer.parameter_groups[bucket_id].main_grad_buffer.bucket_index.size + for _, _, bucket_id in self.grad_reduce_queue + ] + ) + while queue_space > recommeded_queue_capacity: + grad_reduce_event, free_up_grad_bucket, bucket_id = self.grad_reduce_queue.pop(0) + grad_reduce_event.wait() + free_up_grad_bucket() + queue_space -= self.buffer.parameter_groups[ + bucket_id + ].main_grad_buffer.bucket_index.size + else: + recommeded_queue_size = max(0, min(recommeded_queue_size, self.buffer.num_buckets - 1)) + while len(self.grad_reduce_queue) > recommeded_queue_size: + grad_reduce_event, free_up_grad_bucket, _ = self.grad_reduce_queue.pop(0) + grad_reduce_event.wait() + free_up_grad_bucket() + + def mark_item_ready(self, item: torch.Tensor, async_rs: bool = False) -> bool: + """Mark the item ready for reduce-scatter/all-reduce. + Args: + item (torch.Tensor): The item to be marked. + async_rs (bool, optional): Whether to do the reduce-scatter/all-reduce + asynchronously. Defaults to False. + Returns: + bool: True if the item is go for reduce-scatter/all-reduce. + """ + bucket_id = self.buffer.param_to_param_group[item] + assert bucket_id in self.buckets, f"Bucket {bucket_id} is not allocated." + + scaling_factor = self.buffer.gradient_scaling_factor + bucket = self.buckets[bucket_id] + bucket.items.append(item) + assert len(bucket.items) <= bucket.requires_grad_items, "Too many items in the bucket." 
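+        # Launch the reduction only after every gradient expected for this bucket
+        # has been marked ready.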
+ if len(bucket.items) != bucket.requires_grad_items: + return False + + self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING + + current_stream = torch.cuda.current_stream() + reduce_scatter_stream = ( + self.cuda_stream if self.cuda_stream is not None else torch.cuda.current_stream() + ) + reduce_scatter_stream.wait_stream(current_stream) + with torch.cuda.stream(reduce_scatter_stream): + gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer + scaling_factor = gbuf.gradient_scaling_factor + reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, gbuf.ddp_config) + if gbuf.ddp_config.data_parallel_sharding_strategy == 'no_shard': + torch.distributed.all_reduce( + bucket.data, op=reduce_op, group=gbuf.data_parallel_group + ) + else: + grad_shard = gbuf.get_shard_from_bucket(bucket) + grad_shard = torch.empty_like(grad_shard) + torch.distributed.reduce_scatter_tensor( + output=grad_shard, + input=bucket.data, + op=reduce_op, + group=gbuf.data_parallel_group, + ) + if gbuf.is_data_distributed: + # Gradient accumulate on local buffer + local_buffer = gbuf.get_shard_from_local_buffer() + local_buffer += grad_shard + reduce_scatter_view_out_event = reduce_scatter_stream.record_event() + bucket.data_operation_event = reduce_scatter_view_out_event + bucket.status = GradBucketStatus.GRAD_REDUCING + del self.buckets[bucket_id] + + def get_closure(): + def free_up_grad_bucket(): + nonlocal gbuf, local_buffer, bucket_id, bucket + if self.check_nans: + assert not torch.isnan( + local_buffer + ).any(), f"NaN detected in bucket {bucket_id}: {local_buffer}" + + # There is a special case where this bucket is taken for + # gradient accumulating before it has a chance to be free-up (here), + # in which case we free-up here because there is still + # subsequent gradient reducing to be done on this bucket. + if gbuf.is_data_distributed and bucket.status != GradBucketStatus.GRAD_ACCUMULATING: + gbuf.free_bucket_storage() + self.bucket_status[bucket_id] = BucketStatus.EMPTY + + return free_up_grad_bucket + + free_up_grad_bucket = get_closure() + + if async_rs: + self.grad_reduce_queue.append( + (reduce_scatter_view_out_event, free_up_grad_bucket, bucket_id) + ) + return True + + free_up_grad_bucket() + + return True + + +class PrefetchOrder(Enum): + """ + An enumeration of possible prefetch orders for data-parallel operations. + + Attributes: + FORWARD_PASS_ORDER (int): Prefetch in the order of forward pass computation. + BACKWARD_PASS_ORDER (int): Prefetch in the order of backward pass computation. + """ + + FORWARD_PASS_ORDER = 0 + BACKWARD_PASS_ORDER = 1 + + +class AllGatherPipeline: + """ + Pipeline for all-gathering parameters. + """ + + def __init__(self, param_and_grad_buffer: ParamAndGradBuffer) -> None: + self.buffer = param_and_grad_buffer + self.param_gather_event_map = {} + self.bucket_status = {i: BucketStatus.EMPTY for i in range(self.buffer.num_buckets)} + self.bucket_can_be_released = {i: False for i in range(self.buffer.num_buckets)} + + @property + def num_buckets(self): + """Return the number of buckets.""" + return self.buffer.num_buckets + + def reset(self): + """Reset the pipeline state.""" + if len(self.param_gather_event_map) > 0: + warnings.warn( + "There are still pending all-gather tasks, process them." 
+ f"Bucket status: {self.bucket_status}.", + UserWarning, + ) + while len(self.param_gather_event_map) > 0: + bucket_id = next(iter(self.param_gather_event_map)) + self.wait_bucket_ready(bucket_id) + for bucket_id in self.bucket_can_be_released: + self.bucket_can_be_released[bucket_id] = True + self.recycle_unused_buckets() + + assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), ( + f"There are still working buckets, it is not safe to reset. " + f"bucket_status: {self.bucket_status}." + ) + assert all( + [not can_be_released for can_be_released in self.bucket_can_be_released.values()] + ), ( + f"The bucket can be released table is in an abnormal state, not safe to reset. " + f"bucket_can_be_released: {self.bucket_can_be_released}." + ) + + def queue_bucket_to_all_gather( + self, + bucket_id: int, + prefetch: bool = False, + prefetch_order: PrefetchOrder = PrefetchOrder.FORWARD_PASS_ORDER, + suggested_AG_prefetch_size: Optional[int] = None, + ): + """Performs an asynchronous all-gather operation by queuing the task bucket into + a dedicated queue (NCCL CUDA Stream). + + This function is a part of FSDP (Fully Sharded Data Parallel) + implementation that handles the all-gather operation in a queue-based + manner. Instead of executing the all-gather immediately, it enqueues + the operation into a task queue, which helps manage system resources and + prevents overwhelming the GPU memory and communication bandwidth. + + The queued all-gather operation will: + * Collect distributed sharded parameters from all participating processes + * Reconstruct the full parameter tensor + + Args: + bucket_id (int): The bucket ID to be queued for all-gathering. + prefetch (bool, optional): Whether to prefetch the next bucket. Defaults to False. + prefetch_order (PrefetchOrder, optional): The order of prefetching. + Defaults to PrefetchOrder.FORWARD_PASS_ORDER. + suggested_AG_prefetch_size (Optional[int], optional): + The suggested prefetch size for all-gathering. Defaults to None. + """ + parameter_groups = self.buffer.parameter_groups + ag_buckets = [bucket_id] + + # If prefetch is enabled, we will add prefetch buckets to ag_buckets. + if prefetch: + if suggested_AG_prefetch_size is not None: + all_gather_size = parameter_groups[bucket_id].model_weight_buffer.bucket_index.size + while all_gather_size < suggested_AG_prefetch_size: + if prefetch_order == PrefetchOrder.FORWARD_PASS_ORDER: + next_bucket_id = bucket_id + 1 + else: + next_bucket_id = bucket_id - 1 + if next_bucket_id < 0 or next_bucket_id >= self.buffer.num_buckets: + break + + next_group = parameter_groups[next_bucket_id] + ag_buckets.append(next_bucket_id) + + all_gather_size += next_group.model_weight_buffer.bucket_index.size + bucket_id = next_bucket_id + else: + if prefetch_order == PrefetchOrder.FORWARD_PASS_ORDER: + next_bucket_id = bucket_id + 1 + else: + next_bucket_id = bucket_id - 1 + if next_bucket_id >= 0 and next_bucket_id < self.buffer.num_buckets: + ag_buckets.append(next_bucket_id) + + # Launch all-gather operations for all buckets in ag_buckets. 
+ for bucket_id in ag_buckets: + self.all_gather_bucket_and_set_items(bucket_id, async_op=True) + + def wait_bucket_ready(self, bucket_id, empty_ok=False): + """Wait for the bucket to be ready.""" + if self.bucket_status[bucket_id] == BucketStatus.READY_TO_USE: + return + if self.bucket_status[bucket_id] == BucketStatus.EMPTY: + if empty_ok: + return + raise ValueError(f"Bucket {bucket_id} is empty.") + + param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id) + param_gather_event.wait() + mark_bucket_ready_to_use() + + @torch.no_grad() + def release_bucket(self, bucket_id: int): + """Release the bucket.""" + if self.bucket_status[bucket_id] == BucketStatus.EMPTY: + return + + if self.bucket_status[bucket_id] == BucketStatus.COMMUNICATING: + raise ValueError(f"Bucket {bucket_id} is communicating.") + + wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer + wbuf.free_bucket_storage() + self.bucket_status[bucket_id] = BucketStatus.EMPTY + + def recycle_unused_buckets(self): + """Recycle the unused buckets.""" + for bucket_id, can_be_released in self.bucket_can_be_released.items(): + if can_be_released: + self.release_bucket(bucket_id) + self.bucket_can_be_released[bucket_id] = False + + @torch.no_grad() + def all_gather_bucket_and_set_items(self, bucket_id: int, async_op: bool = False) -> None: + """All-gather the bucket and set the items.""" + self.bucket_can_be_released[bucket_id] = False + if self.bucket_status[bucket_id] != BucketStatus.EMPTY: + return + + self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING + wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer + + # Lazy release the unused buckets. + self.recycle_unused_buckets() + bucket = wbuf.fetch_bucket(and_allocate_params_data=True) + param_gather_event = torch.distributed.all_gather_into_tensor( + output_tensor=bucket.data, + input_tensor=wbuf.get_shard_from_local_buffer(), + group=wbuf.data_parallel_group, + async_op=async_op, + ) + + def get_closure(): + @torch.no_grad() + def mark_bucket_ready_to_use(): + nonlocal wbuf, bucket_id + self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE + + return mark_bucket_ready_to_use + + mark_bucket_ready_to_use = get_closure() + + if async_op: + self.param_gather_event_map[bucket_id] = (param_gather_event, mark_bucket_ready_to_use) + return + mark_bucket_ready_to_use() + + +@torch.no_grad() +def gradient_reduce_preprocessing(grad_data, scaling_factor, ddp_config): + """ + Gradient reduce preprocessing for gradient averaging and gradient scaling. + """ + + if scaling_factor is None: + reduce_op = torch.distributed.ReduceOp.SUM + elif ddp_config.average_in_collective: + reduce_op = torch.distributed.ReduceOp.AVG + elif ddp_config.gradient_reduce_div_fusion and grad_data.dtype != torch.bfloat16: + reduce_op = torch.distributed._make_nccl_premul_sum(scaling_factor) + else: + grad_data.mul_(scaling_factor) + reduce_op = torch.distributed.ReduceOp.SUM + + return reduce_op + + +def check_gpu_memory(threshold=0.9): + """ + Check if the GPU memory is over the threshold. + Args: + threshold (float, optional): The threshold to check if the GPU memory is over. + Defaults to 0.9. + Returns: + bool: True if the GPU memory is over the threshold. 
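+
+    Example:
+        A usage sketch mirroring how the buffer initialization path uses this helper;
+        the 0.8 threshold is illustrative, not a recommended value.
+
+        >>> if check_gpu_memory(threshold=0.8):
+        ...     torch.cuda.empty_cache()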
+ """ + if not torch.cuda.is_available(): + return False + device = torch.cuda.current_device() + allocated = torch.cuda.memory_allocated(device) + reserved = torch.cuda.memory_reserved(device) + total = torch.cuda.get_device_properties(device).total_memory + + allocated_ratio = allocated / total + reserved_ratio = reserved / total + + near_full = allocated_ratio >= threshold or reserved_ratio >= threshold + + if near_full: + log_on_each_pipeline_stage( + logger, + logging.INFO, + f"GPU Memory: Allocated: {allocated_ratio:.2%}, Reserved: {reserved_ratio:.2%}", + ) + return near_full + + +class ResetParametersContext: + """ + Context manager for resetting parameters for meta device initialization module. + """ + + def __init__(self, init_param_with_fp8=False, with_cuda_rng_tracker=False): + self.init_param_with_fp8 = init_param_with_fp8 + self.with_cuda_rng_tracker = with_cuda_rng_tracker + + def __enter__(self): + self.stack = ExitStack() + if self.init_param_with_fp8: + args = {"enabled": True} + if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters: + args["preserve_high_precision_init_val"] = True + self.stack.enter_context(fp8_model_init(**args)) + + if self.with_cuda_rng_tracker: + self.stack.enter_context(get_cuda_rng_tracker().fork()) + + return self + + def __exit__(self, *exc_details): + self.stack.__exit__(*exc_details) diff --git a/megatron/core/distributed/data_parallel_base.py b/megatron/core/distributed/data_parallel_base.py index aed576a7a35d8d070a5162dc6c7776ad1849ad93..24ab89429bccd3776120608427de597e27c465d2 100644 --- a/megatron/core/distributed/data_parallel_base.py +++ b/megatron/core/distributed/data_parallel_base.py @@ -1,96 +1,96 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from contextlib import contextmanager - -import torch - -from ..transformer.module import MegatronModule -from ..transformer.transformer_config import TransformerConfig - - -class _BaseDataParallel(MegatronModule): - """A template class for DistributedDataParallel implementations.""" - - def __init__(self, config: TransformerConfig, module: torch.nn.Module): - super().__init__(config=config) - self.module = module - - def forward(self, *inputs, **kwargs): - """ - Calls the wrapped module's forward() method. - """ - return self.module(*inputs, **kwargs) - - @contextmanager - def no_sync(self): - """ - Context manager that turns off gradient synchronization. - """ - try: - yield - finally: - pass - - def start_grad_sync(self, *unused): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operations - for all model gradients. - - When overlap_grad_reduce is set to True, dispatches asynchronous communication - calls. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - pass - - def scale_gradients(self, scaling_factor: float) -> None: - """Scale all gradients inside the buffers by `scaling_factor`.""" - pass - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operations - for all model gradients. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - calls to complete. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - pass - - def zero_grad_buffer(self): - """ - Zeros out all grad buffers. Needs to be called at the beginning of each - training iteration. - """ - pass - - def broadcast_params(self): - """ - Syncs parameters across all DP ranks. 
- """ - pass - - def state_dict(self, prefix='', keep_vars=False): - """ - Returns a dictionary containing references to the whole state of the - wrapped module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. Parameters and buffers - set to None are not included. - """ - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """ - Returns wrapped module's state_dict for checkpoint saving. - """ - return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) - - def load_state_dict(self, state_dict, strict=True): - """ - Copies parameters and buffers from state_dict into the wrapped module and its - descendants. If strict is True, then the keys of state_dict must exactly match - the keys returned by this module’s state_dict() function. - """ - self.module.load_state_dict(state_dict, strict=strict) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from contextlib import contextmanager + +import torch + +from ..transformer.module import MegatronModule +from ..transformer.transformer_config import TransformerConfig + + +class _BaseDataParallel(MegatronModule): + """A template class for DistributedDataParallel implementations.""" + + def __init__(self, config: TransformerConfig, module: torch.nn.Module): + super().__init__(config=config) + self.module = module + + def forward(self, *inputs, **kwargs): + """ + Calls the wrapped module's forward() method. + """ + return self.module(*inputs, **kwargs) + + @contextmanager + def no_sync(self): + """ + Context manager that turns off gradient synchronization. + """ + try: + yield + finally: + pass + + def start_grad_sync(self, *unused): + """ + Initiates grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, dispatches asynchronous communication + calls. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale all gradients inside the buffers by `scaling_factor`.""" + pass + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, waits for asynchronous communication + calls to complete. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def zero_grad_buffer(self): + """ + Zeros out all grad buffers. Needs to be called at the beginning of each + training iteration. + """ + pass + + def broadcast_params(self): + """ + Syncs parameters across all DP ranks. + """ + pass + + def state_dict(self, prefix='', keep_vars=False, destination=None): + """ + Returns a dictionary containing references to the whole state of the + wrapped module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. Parameters and buffers + set to None are not included. + """ + return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, destination=destination) + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """ + Returns wrapped module's state_dict for checkpoint saving. 
+ """ + return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + """ + Copies parameters and buffers from state_dict into the wrapped module and its + descendants. If strict is True, then the keys of state_dict must exactly match + the keys returned by this module’s state_dict() function. + """ + self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py index fbcd930191377ff24b47b8bd2290d9f395c6ebd7..fe7cec5874e8d5ad7973c087e9ae24b441170178 100644 --- a/megatron/core/distributed/distributed_data_parallel_config.py +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -1,49 +1,78 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class DistributedDataParallelConfig: - """Configuration for DistributedDataParallel.""" - - grad_reduce_in_fp32: bool = False - """If true, reduce grads in fp32.""" - - overlap_grad_reduce: bool = False - """If true, overlap grad all-reduce / reduce-scatter with backward compute.""" - - overlap_param_gather: bool = False - """If true, overlap param all-gather with forward compute.""" - - align_param_gather: bool = False - """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each - PP stage will independently launch as needed. - """ - - use_distributed_optimizer: bool = False - """If true, issue reduce-scatter collectives to aggregate gradients and clean up - originally allocated model parameters, otherwise issue all-reduce collectives. - """ - - num_distributed_optimizer_instances: int = 1 - """Sets the factor by which the DP domain is sharded to have the partial DistOpt - enabled. Defaults to 1, which means DistOpt is across entire DP domain. - """ - - check_for_nan_in_grad: bool = False - """ If true, check for NaNs in gradients _before_ communication collective.""" - - bucket_size: Optional[int] = None - """Maximum number of parameters in each bucket. If unspecified, MCore uses a default - value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger - buckets to ensure collectives do not become latency-bound).""" - - average_in_collective: bool = False - """If true, compute average in collective directly, as opposed to dividing by the - dp_size first and then computing sum in the collective.""" - - fp8_param_gather: bool = False - """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and - perform the param all-gather in fp8.""" +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class DistributedDataParallelConfig: + """Configuration for DistributedDataParallel.""" + + grad_reduce_in_fp32: bool = False + """If true, reduce grads in fp32.""" + + overlap_grad_reduce: bool = False + """If true, overlap grad all-reduce / reduce-scatter with backward compute.""" + + overlap_param_gather: bool = False + """If true, overlap param all-gather with forward compute.""" + + align_param_gather: bool = False + """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each + PP stage will independently launch as needed. 
+ """ + + use_distributed_optimizer: bool = False + """If true, issue reduce-scatter collectives to aggregate gradients and clean up + originally allocated model parameters, otherwise issue all-reduce collectives. + """ + + num_distributed_optimizer_instances: int = 1 + """Sets the factor by which the DP domain is sharded to have the partial DistOpt + enabled. Defaults to 1, which means DistOpt is across entire DP domain. + """ + + check_for_nan_in_grad: bool = False + """If true, check for NaNs and Infs in gradients _before_ communication collective.""" + + check_for_large_grads: bool = False + """If true, check for unexpectedly large gradients _before_ communication collective.""" + + bucket_size: Optional[int] = None + """Maximum number of parameters in each bucket. If unspecified, MCore uses a default + value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger + buckets to ensure collectives do not become latency-bound).""" + + pad_buckets_for_high_nccl_busbw: bool = False + """If true, make sure the bucket size is divisible by a large power of 2 (2^16) to + ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL + message size (which for ring algorithms is bucket_size / dp_size) apparently needs + to be divisible by a power of 2 for high busbw.""" + + average_in_collective: bool = False + """If true, compute average in collective directly, as opposed to dividing by the + dp_size first and then computing sum in the collective.""" + + fp8_param_gather: bool = False + """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and + perform the param all-gather in fp8.""" + + use_custom_fsdp: bool = False + """If true, use the FSDP code path for DDP.""" + + data_parallel_sharding_strategy: str = 'no_shard' + """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim', + 'optim_grads', 'optim_grads_params'.""" + + gradient_reduce_div_fusion: bool = True + """If true, perform gradient reduce and division fusion.""" + + suggested_communication_unit_size: int = 400_000_000 + """When batch communication is needed across multiple buckets, + this environment variable guides the size of communication unit size.""" + + preserve_fp32_weights: bool = True + """If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer.""" + + keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False + """If true, keep the fp8 transpose cache when using custom FSDP.""" diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index db31fc0131459cf74b7ef44676992509ffa3c135..e04da87bd07cc79fe5fb7b877af9fc47b4114433 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -1,284 +1,325 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from typing import List, Optional, Union - -import torch -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -try: - from torch.distributed._tensor import DTensor, distribute_tensor - - HAVE_DTENSOR = True -except ImportError: - HAVE_DTENSOR = False - -from .. import parallel_state -from ..transformer.transformer_config import TransformerConfig -from ..utils import get_attr_wrapped_model, get_model_config - - -def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: - """ - Unshards the input tensor if it is a DTensor and otherwise returns the - tensor unmodified. 
- - Args: - tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard. - - Returns: - An unsharded version of the input tensor if it is a DTensor, or the - input tensor unmodified if it is not a DTensor. - """ - if HAVE_DTENSOR and isinstance(tensor, DTensor): - unsharded_tensor = tensor.full_tensor() - for k, v in vars(tensor).items(): - setattr(unsharded_tensor, k, v) - return unsharded_tensor - return tensor - - -def _reshard_if_dtensor( - tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"] -) -> Union[torch.Tensor, "DTensor"]: - """ - Reshards the input tensor to match the sharding configuration of the - reference tensor if the reference tensor is a DTensor. Otherwise, returns - the reference tensor unmodified. - - Args: - tensor_to_shard (torch.Tensor): The tensor to be potentially sharded. - reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor - for the sharding configuration. - - Returns: - Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's - configuration, or the reference tensor itself if it is not a DTensor. - """ - if HAVE_DTENSOR and isinstance(reference_tensor, DTensor): - sharded_tensor = distribute_tensor( - tensor_to_shard, - device_mesh=reference_tensor.device_mesh, - placements=reference_tensor.placements, - ) - for k, v in vars(reference_tensor).items(): - setattr(sharded_tensor, k, v) - return sharded_tensor - return reference_tensor - - -def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce conditional embedding grads. - - Reduce grads across all the pp stages to ensure that parameters of the conditional embedders - (e.g., timestep embedder, FPS embedder, label embedder) stay in sync. - This is for the models with replicated embedders on each PP / VPP rank, like diffusion models. - """ - - if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr( - config, "has_cond_embedder", False - ): - grads_dict = {} - for model_chunk in model: - for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): - if param.requires_grad and getattr(param, 'pipeline_parallel', False): - grad = param.main_grad - if name in grads_dict: - # Add all the virtual PP rank's gradients to - # the first local virtual PP rank. - grads_dict[name][0].add_(grad) - # Append to the end for later update after cross-rank reduce. - grads_dict[name].append(grad) - else: - grads_dict[name] = [grad] - if grads_dict: - # All-reduce the gradient on the first VPP rank. - grads = [param_grad[0] for _, param_grad in grads_dict.items()] - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_pipeline_model_parallel_group() - ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - # Update the gradients on other VPP ranks. - for grads in grads_dict.values(): - for grad in grads[1:]: - grad.copy_(grads[0]) - - -def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce word embedding grads. - - Reduce grads across first and last stages to ensure that word_embeddings parameters stay in - sync. 
- """ - - if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 - ): - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - model_module = model[0] - elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): - model_module = model[-1] - else: # We do not support an interleaved schedule for models with encoders yet. - model_module = model[0] - - model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) - if model_module.share_embeddings_and_output_weights: - weight = model_module.shared_embedding_or_output_weight() - grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" - orig_grad = getattr(weight, grad_attr) - grad = _unshard_if_dtensor(orig_grad) - torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) - setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) - - -def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce position_embeddings grad across encoder and decoder stages to ensure that position - embeddings parameters stay in sync. - """ - if ( - parallel_state.is_rank_in_position_embedding_group() - and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1 - ): - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - model_module = model[0] - elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): - model_module = model[-1] - else: # We do not support an interleaved schedule for models with encoders yet. - model_module = model[0] - - model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) - assert hasattr(model_module, 'position_embeddings') - weight = model_module.position_embeddings.weight - grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" - orig_grad = getattr(weight, grad_attr) - grad = _unshard_if_dtensor(orig_grad) - torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group()) - setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) - - -def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce both word and position embeddings. - """ - _allreduce_word_embedding_grads(model, config) - _allreduce_position_embedding_grads(model, config) - - -def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce layernorm grads (for sequence parallelism). 
- """ - - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and ( - config.sequence_parallel or config.qk_layernorm - ): - params = [] - grads = [] - for model_chunk in model: - for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): - if ( - param.requires_grad - and getattr(param, 'sequence_parallel', False) - or 'q_layernorm' in name - or 'k_layernorm' in name - ): - params.append(param) - grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" - grad = getattr(param, grad_attr) - grad = _unshard_if_dtensor(grad) - grads.append(grad.data) - if grads: - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_tensor_model_parallel_group() - ) - for param, buf, synced in zip( - params, grads, _unflatten_dense_tensors(coalesced, grads) - ): - buf.copy_(synced) - grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" - orig_grad = getattr(param, grad_attr) - setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad)) - - -def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None): - """ - All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, - embedding grads across first and last pipeline stages (if not tied), - scale gradients by `num_tokens`. - """ - - config = get_model_config(model[0]) - - # All-reduce / reduce-scatter across DP replicas. - if config.timers is not None: - config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time) - for model_chunk in model: - model_chunk.finish_grad_sync() - if config.timers is not None: - config.timers('all-grads-sync').stop() - - # All-reduce t_embedder grads (for pp & vpp of DiT). - if config.timers is not None: - config.timers('conditional-embedder-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_conditional_embedding_grads(model, config) - if config.timers is not None: - config.timers('conditional-embedder-grads-all-reduce').stop() - - # All-reduce layer-norm grads (for sequence parallelism). - if config.timers is not None: - config.timers('layernorm-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_layernorm_grads(model, config) - if config.timers is not None: - config.timers('layernorm-grads-all-reduce').stop() - - # All-reduce embedding grads (for pipeline parallelism). - if config.timers is not None: - config.timers('embedding-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_embedding_grads(model, config) - if config.timers is not None: - config.timers('embedding-grads-all-reduce').stop() - - # normalize gradients for per-token loss normalization. - # if we are using by the number of tokens, then we use that as a divisor. this number - # will be the total number of non-padded tokens in the global batch. - if num_tokens is not None: - - # the number of tokens is only present on the last stage, so broadcast it - # to the other ranks in the pipeline parallel group. 
- last_rank = parallel_state.get_pipeline_model_parallel_last_rank() - pp_group = parallel_state.get_pipeline_model_parallel_group() - - if not isinstance(last_rank, list): - assert not isinstance(last_rank, list) - last_rank = [last_rank] - assert not isinstance(pp_group, list) - pp_group = [pp_group] - - # need to do a broadcast for every pp group, even though num_tokens should be the same. - num_tokens_list = [] - for lr, group in zip(last_rank, pp_group): - torch.distributed.broadcast(num_tokens, src=lr, group=group) - num_tokens_list.append(torch.clone(num_tokens)) - assert all(x.item() == num_tokens_list[0] for x in num_tokens_list) - - # all-reduce across DP ranks. - torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group()) - for model_chunk in model: - if num_tokens > 0: - scaling = 1.0 / num_tokens - model_chunk.scale_gradients(scaling) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import List, Optional, Union + +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +try: + from torch.distributed._tensor import DTensor, distribute_tensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + +from .. import parallel_state +from ..transformer.moe.moe_utils import get_updated_expert_bias +from ..transformer.transformer_config import TransformerConfig +from ..utils import get_attr_wrapped_model, get_model_config + + +def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False): + if use_custom_fsdp: + return "fsdp_managed_main_grad" + if hasattr(param, "main_grad"): + return "main_grad" + return "grad" + + +def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: + """ + Unshards the input tensor if it is a DTensor and otherwise returns the + tensor unmodified. + + Args: + tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard. + + Returns: + An unsharded version of the input tensor if it is a DTensor, or the + input tensor unmodified if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(tensor, DTensor): + unsharded_tensor = tensor.full_tensor() + for k, v in vars(tensor).items(): + setattr(unsharded_tensor, k, v) + return unsharded_tensor + return tensor + + +def _reshard_if_dtensor( + tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"] +) -> Union[torch.Tensor, "DTensor"]: + """ + Reshards the input tensor to match the sharding configuration of the + reference tensor if the reference tensor is a DTensor. Otherwise, returns + the reference tensor unmodified. + + Args: + tensor_to_shard (torch.Tensor): The tensor to be potentially sharded. + reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor + for the sharding configuration. + + Returns: + Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's + configuration, or the reference tensor itself if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(reference_tensor, DTensor): + sharded_tensor = distribute_tensor( + tensor_to_shard, + device_mesh=reference_tensor.device_mesh, + placements=reference_tensor.placements, + ) + for k, v in vars(reference_tensor).items(): + setattr(sharded_tensor, k, v) + return sharded_tensor + return reference_tensor + + +def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce conditional embedding grads. 
+ + Reduce grads across all the pp stages to ensure that parameters of the conditional embedders + (e.g., timestep embedder, FPS embedder, label embedder) stay in sync. + This is for the models with replicated embedders on each PP / VPP rank, like diffusion models. + """ + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr( + config, "has_cond_embedder", False + ): + grads_dict = {} + for model_chunk in model: + for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): + if param.requires_grad and getattr(param, 'pipeline_parallel', False): + grad = param.main_grad + if name in grads_dict: + # Add all the virtual PP rank's gradients to + # the first local virtual PP rank. + grads_dict[name][0].add_(grad) + # Append to the end for later update after cross-rank reduce. + grads_dict[name].append(grad) + else: + grads_dict[name] = [grad] + if grads_dict: + # All-reduce the gradient on the first VPP rank. + grads = [param_grad[0] for _, param_grad in grads_dict.items()] + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_pipeline_model_parallel_group() + ) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + # Update the gradients on other VPP ranks. + for grads in grads_dict.values(): + for grad in grads[1:]: + grad.copy_(grads[0]) + + +def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce word embedding grads. + + Reduce grads across first and last stages to ensure that word_embeddings parameters stay in + sync. + """ + + if ( + parallel_state.is_rank_in_embedding_group(ignore_virtual=True) + and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 + ): + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + model_module = model[0] + elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): + model_module = model[-1] + else: # We do not support an interleaved schedule for models with encoders yet. + model_module = model[0] + + ddp_config = model_module.ddp_config + model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) + if model_module.share_embeddings_and_output_weights: + weight = model_module.shared_embedding_or_output_weight() + grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp) + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) + torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) + + +def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce position_embeddings grad across encoder and decoder stages to ensure that position + embeddings parameters stay in sync. + """ + if ( + parallel_state.is_rank_in_position_embedding_group() + and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1 + ): + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + model_module = model[0] + elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): + model_module = model[-1] + else: # We do not support an interleaved schedule for models with encoders yet. 
+ model_module = model[0] + + ddp_config = model_module.ddp_config + model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) + assert hasattr(model_module, 'position_embeddings') + weight = model_module.position_embeddings.weight + grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp) + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) + torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) + + +def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce both word and position embeddings. + """ + _allreduce_word_embedding_grads(model, config) + _allreduce_position_embedding_grads(model, config) + + +def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce layernorm grads (for sequence parallelism). + """ + + # All-reduce layernorm parameters across model parallel nodes + # when sequence parallelism is used + if parallel_state.get_tensor_model_parallel_world_size() > 1 and ( + config.sequence_parallel or config.qk_layernorm + ): + params = [] + grads = [] + for model_chunk in model: + for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): + if param.requires_grad and ( + getattr(param, 'sequence_parallel', False) + or 'q_layernorm' in name + or 'k_layernorm' in name + ): + params.append(param) + grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp) + grad = getattr(param, grad_attr) + grad = _unshard_if_dtensor(grad) + grads.append(grad.data) + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_tensor_model_parallel_group() + ) + for param, buf, synced in zip( + params, grads, _unflatten_dense_tensors(coalesced, grads) + ): + buf.copy_(synced) + grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp) + orig_grad = getattr(param, grad_attr) + setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad)) + + +def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig): + """ + Update the expert bias of the router for a global batch. + This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks + """ + tokens_per_expert_list = [] + expert_bias_list = [] + for model_chunk in model: + for module in get_attr_wrapped_model(model_chunk, 'modules')(): + if hasattr(module, 'expert_bias'): + tokens_per_expert_list.append(module.local_tokens_per_expert) + expert_bias_list.append(module.expert_bias) + # For hybrid models with both MoE and Dense layers, this list can be empty. 
+    if len(expert_bias_list) == 0:
+        return
+    stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
+    stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
+    stacked_updated_expert_bias = get_updated_expert_bias(
+        stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
+    )
+
+    for tokens_per_expert, expert_bias, updated_expert_bias in zip(
+        tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
+    ):
+        tokens_per_expert.zero_()
+        expert_bias.copy_(updated_expert_bias)
+
+
+def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
+    """
+    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
+    embedding grads across first and last pipeline stages (if not tied),
+    scale gradients by `num_tokens`.
+    """
+
+    config = get_model_config(model[0])
+
+    # All-reduce / reduce-scatter across DP replicas.
+    if config.timers is not None:
+        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
+    for model_chunk in model:
+        model_chunk.finish_grad_sync()
+    if config.timers is not None:
+        config.timers('all-grads-sync').stop()
+
+    # All-reduce t_embedder grads (for pp & vpp of DiT).
+    if config.timers is not None:
+        config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
+            barrier=config.barrier_with_L1_time
+        )
+    _allreduce_conditional_embedding_grads(model, config)
+    if config.timers is not None:
+        config.timers('conditional-embedder-grads-all-reduce').stop()
+
+    # All-reduce layer-norm grads (for sequence parallelism).
+    if config.timers is not None:
+        config.timers('layernorm-grads-all-reduce', log_level=1).start(
+            barrier=config.barrier_with_L1_time
+        )
+    _allreduce_layernorm_grads(model, config)
+    if config.timers is not None:
+        config.timers('layernorm-grads-all-reduce').stop()
+
+    # All-reduce embedding grads (for pipeline parallelism).
+    if config.timers is not None:
+        config.timers('embedding-grads-all-reduce', log_level=1).start(
+            barrier=config.barrier_with_L1_time
+        )
+    _allreduce_embedding_grads(model, config)
+    if config.timers is not None:
+        config.timers('embedding-grads-all-reduce').stop()
+
+    if config.moe_router_enable_expert_bias:
+        _update_router_expert_bias(model, config)
+
+    # normalize gradients for per-token loss normalization.
+    # if we are normalizing by the number of tokens, then we use that as the divisor. this
+    # number will be the total number of non-padded tokens in the global batch.
+    if num_tokens is not None:
+
+        # the number of tokens is only present on the last stage, so broadcast it
+        # to the other ranks in the pipeline parallel group.
+        last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
+
+        if not isinstance(last_rank, list):
+            assert not isinstance(last_rank, list)
+            last_rank = [last_rank]
+            assert not isinstance(pp_group, list)
+            pp_group = [pp_group]
+
+        # need to do a broadcast for every pp group, even though num_tokens should be the same.
+        num_tokens_list = []
+        for lr, group in zip(last_rank, pp_group):
+            torch.distributed.broadcast(num_tokens, src=lr, group=group)
+            num_tokens_list.append(torch.clone(num_tokens))
+        assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
+
+        # all-reduce across DP ranks.
+ torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group()) + for model_chunk in model: + if num_tokens > 0: + scaling = 1.0 / num_tokens + model_chunk.scale_gradients(scaling) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 5095a7c7f3b44f8a7040d3263873aebd2a76b681..5929498b9790804da8f15490afa7839286bdccda 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -1,836 +1,882 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -import logging -import math -from contextlib import nullcontext -from enum import Enum -from typing import Dict, List, Optional - -import torch -from torch.distributed import _coalescing_manager - -from megatron.core.rerun_state_machine import get_rerun_state_machine - -from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage -from .distributed_data_parallel_config import DistributedDataParallelConfig - -logger = logging.getLogger(__name__) - - -if is_torch_min_version("1.13.0"): - dist_all_gather_func = torch.distributed.all_gather_into_tensor - dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor -else: - dist_all_gather_func = torch.distributed._all_gather_base - dist_reduce_scatter_func = torch.distributed._reduce_scatter_base - - -class BufferType(Enum): - """ - Enumeration for buffer type. - """ - - PARAM = 1 - GRAD = 2 - - -def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): - """ - Shard buffer into data_parallel_world_size chunks of equal size. - """ - assert buffer.numel() % data_parallel_world_size == 0 - shard_size = buffer.numel() // data_parallel_world_size - sharded_buffer = [ - buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size) - ] - return sharded_buffer - - -class _ParamAndGradBucket: - """ - Bucket to keep track of a subset of the model's parameters and gradients. - - Args: - params: List of parameters whose gradients are collated in this bucket. - param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for. - grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for. - offset: Offset of this bucket's view in the larger _ParamAndGradBuffer. - numel_unpadded: Number of unpadded elements in bucket. - gradient_scaling_factor: This factor is utilized to scale gradients prior to their - communication. Its application is twofold: it facilitates the averaging of gradients - and the scaling of gradients in the context of the Mixture of Experts (MoE) model. - bucket_id: Index of bucket in buffer. - """ - - def __init__( - self, - params: List[torch.nn.Parameter], - param_data: Optional[torch.Tensor], - grad_data: torch.Tensor, - offset: int, - numel_unpadded: int, - gradient_scaling_factor: float, - bucket_id: int, - ): - self.params_list = params - self.params = set(params) - # Make sure there are no duplicate params. - assert len(self.params_list) == len(self.params) - self.param_data = param_data - self.grad_data = grad_data - # The distributed optimizer needs to keep track of this bucket's offset - # within the full grad_buffer. - self.offset = offset - self.numel_unpadded = numel_unpadded - self.gradient_scaling_factor = gradient_scaling_factor - self.bucket_id = bucket_id - - -class _ParamAndGradBucketGroup: - """ - Put multiple buckets into a group so that their communications can be aggregated together. 
- Provides functionality to register when params in the bucket group have grads ready to be - synced; an asynchronous communication call is automatically launched when _all_ params in - the bucket group have grads ready. - - Args: - buckets: A list of buckets. - ddp_config: DistributedDataParallel config object. - collective_group: intra_distributed_optimizer_instance_group if using distributed - optimizer, data_parallel_group if not. - collective_group_size: World size using the intra data-parallel group. - """ - - def __init__( - self, - buckets: List[_ParamAndGradBucket], - ddp_config: DistributedDataParallelConfig, - collective_group: torch.distributed.ProcessGroup, - collective_group_size: int, - ): - self.buckets = buckets - self.ddp_config = ddp_config - - if self.ddp_config.use_distributed_optimizer: - self.intra_distributed_optimizer_instance_group = collective_group - self.intra_distributed_optimizer_instance_size = collective_group_size - self.intra_distributed_optimizer_instance_rank = torch.distributed.get_rank( - group=collective_group - ) - else: - self.data_parallel_group = collective_group - - # State for bookkeeping: params is the set of parameters this bucket group is - # responsible for, params_with_grad is the set of parameters with grads - # available. When overlap_grad_reduce is True, communication (all-reduce - # or reduce-scatter) is issued when params_with_grad equals params. - self.param_to_bucket = {} - self.params = set() - for bucket in self.buckets: - for param in bucket.params_list: - self.param_to_bucket[param] = bucket - self.params.add(param) - - self.next_param_gather_bucket_group = None - - if self.ddp_config.num_distributed_optimizer_instances > 1: - self.inter_distributed_optimizer_instance_group = None - self.communication_stream = None - - self.reset() - self.param_gather_handle = None - self.param_gather_dispatched = False - self.grad_reduce_handle = None - - def reset(self): - """ - Reset metadata in bucket group in preparation for the next iteration of training. - """ - self.params_with_grad = set() - self.is_last_microbatch = True - - def check_for_nan_in_grad(self): - """ - Make sure norm of grads in bucket are not NaN prior to data-parallel - all-reduce / reduce-scatter. - """ - rerun_state_machine = get_rerun_state_machine() - for i in range(len(self.buckets)): - rerun_state_machine.validate_result( - result=self.buckets[i].grad_data.norm(p=2), - rejection_func=torch.isnan, - message=f"found NaN in local grad norm for bucket #{i} " - f"in backward pass before data-parallel communication collective", - tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward - fatal=True, - ) - - def start_param_sync(self, force_sync: bool = False): - """ - Initiates all necessary param all-gathers for this bucket. - - When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous - communication call (unless force_sync is True). When ddp_config.overlap_param_gather - is set to False, makes synchronous call. - - Args: - force_sync (bool, optional): force synchronous collective regardless of - other settings if true. - """ - assert self.ddp_config.use_distributed_optimizer - - if force_sync: - if self.param_gather_handle is not None: - self.param_gather_handle.wait() - self.param_gather_handle = None - return - else: - assert self.param_gather_handle is None - - async_op = self.ddp_config.overlap_param_gather and not force_sync - # Coalesce communication kernels across buckets in the bucket group. 
- with _coalescing_manager( - self.intra_distributed_optimizer_instance_group, async_ops=async_op - ) as cm: - for bucket in self.buckets: - local_data_view = shard_buffer( - bucket.param_data, self.intra_distributed_optimizer_instance_size - )[self.intra_distributed_optimizer_instance_rank] - dist_all_gather_func( - bucket.param_data, - local_data_view, - group=self.intra_distributed_optimizer_instance_group, - async_op=async_op, - ) - if async_op: - self.param_gather_handle = cm - else: - # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, - # `cm` is not None, which is different from when `_coalescing_manager` is not used in - # which case the torch.distributed._all_gather_base() will return None. In order to - # maintain consistency with prior code, we need to manually set communication handle to - # None. - self.param_gather_handle = None - self.param_gather_dispatched = True - - def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): - """ - Finishes param sync communication operation for this bucket. Dispatches - next bucket's param sync if available, unless skip_next_bucket_dispatch - is True. - - When ddp_config.overlap_param_gather is set to True, waits for asynchronous - communication call to complete (and dispatches one if one is not already - outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to - False. - - Args: - skip_next_bucket_dispatch (bool, optional): if true, dispatch next - bucket's communication if available. - """ - assert self.ddp_config.use_distributed_optimizer - assert self.ddp_config.overlap_param_gather - - # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first - # AG bucket in first model chunk if ddp_config.align_param_gather is False). - if not self.param_gather_dispatched: - self.start_param_sync() - - if self.param_gather_handle is not None: - self.param_gather_handle.wait() - self.param_gather_handle = None - # Dispatch next bucket's asynchronous param AG. - if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch: - self.next_param_gather_bucket_group.start_param_sync() - - def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the bucket group. - - When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous - communication call. When ddp_config.overlap_grad_reduce is set to False, makes - synchronous call. - """ - assert ( - self.grad_reduce_handle is None - ), 'Should not have multiple communication calls outstanding at once' - - if self.ddp_config.check_for_nan_in_grad: - self.check_for_nan_in_grad() - - # gradient_scaling_factor already takes into account whether we are computing - # an average or sum in the data-parallel collective. - for bucket in self.buckets: - if bucket.gradient_scaling_factor != 1.0: - bucket.grad_data *= bucket.gradient_scaling_factor - - # Decide reduce_op. - reduce_op = torch.distributed.ReduceOp.SUM - if self.ddp_config.average_in_collective: - reduce_op = torch.distributed.ReduceOp.AVG - - # We use the following stream synchronization for the gradient reduction - # within and across DistOpt instances. - - # Compute Stream: -------------Gradient compute------------------- - # Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)------- - # NCCL Stream: -------RS------ -------AR------ - - # Use async communications only when overlap_grad_reduce is True. 
- async_op = ( - self.ddp_config.overlap_grad_reduce - and self.ddp_config.num_distributed_optimizer_instances == 1 - ) - if ( - self.ddp_config.num_distributed_optimizer_instances > 1 - and self.ddp_config.overlap_grad_reduce - ): - # Assign a communication stream if we have multiple DistOpt instances and we - # need to overlap communication. - stream_context = torch.cuda.stream(self.communication_stream) - - # The RS/AR communication stream needs to wait for the default stream - # to complete its gradient computation before launching the next - # gradient reduction collective. - self.communication_stream.wait_stream(torch.cuda.default_stream()) - else: - stream_context = nullcontext() - - if self.ddp_config.use_distributed_optimizer: - communication_group = self.intra_distributed_optimizer_instance_group - else: - communication_group = self.data_parallel_group - - # Coalesce communication kernels across buckets in the bucket group. - with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm: - for bucket in self.buckets: - if self.ddp_config.use_distributed_optimizer: - local_data_view = shard_buffer( - bucket.grad_data, self.intra_distributed_optimizer_instance_size - )[self.intra_distributed_optimizer_instance_rank] - dist_reduce_scatter_func( - local_data_view, - bucket.grad_data, - op=reduce_op, - group=communication_group, - async_op=async_op, - ) - else: - torch.distributed.all_reduce( - bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op - ) - - # With multiple DistOpt instances, we need to all-reduce across instances. - if ( - self.ddp_config.use_distributed_optimizer - and self.ddp_config.num_distributed_optimizer_instances > 1 - ): - - # Create a new coalescing manager for the inter-instance all-reduce. - with stream_context, _coalescing_manager( - self.inter_distributed_optimizer_instance_group, async_ops=async_op - ) as cm: - for bucket in self.buckets: - local_data_view = shard_buffer( - bucket.grad_data, self.intra_distributed_optimizer_instance_size - )[self.intra_distributed_optimizer_instance_rank] - - torch.distributed.all_reduce( - local_data_view, - op=reduce_op, - group=self.inter_distributed_optimizer_instance_group, - async_op=async_op, - ) - - if async_op: - self.grad_reduce_handle = cm - else: - # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, - # `cm` is not None, which is different from when `_coalescing_manager` is not used in - # which case the torch.distributed._reduce_scatter_base() will return None. In order to - # maintain consistency with prior code, we need to manually set communication handle to - # None. - self.grad_reduce_handle = None - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the bucket group. - - When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous - communication call to complete. When ddp_config.overlap_grad_reduce is set to False, - makes synchronous call. - """ - self.param_gather_dispatched = False - # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. - if not self.ddp_config.overlap_grad_reduce: - self.start_grad_sync() - return - # When using multiple DistOpt instances, we don't need to sync here as we launch - # communications on a separate communication stream. 
- if self.ddp_config.num_distributed_optimizer_instances > 1: - torch.cuda.default_stream().wait_stream(self.communication_stream) - return - assert self.grad_reduce_handle is not None, ( - f'Communication call has not been issued for this bucket ' - f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)' - ) - self.grad_reduce_handle.wait() - self.grad_reduce_handle = None - - def register_grad_ready(self, param: torch.nn.Parameter): - """ - Registers grads for the passed-in param to be "ready" for grad sync. - - When the number of microbatches is greater than 1, we only want to register - grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce - is True. - """ - assert ( - self.ddp_config.overlap_grad_reduce - ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' - if self.is_last_microbatch: - assert param in self.param_to_bucket, 'Param is not in the bucket group' - assert param not in self.params_with_grad, 'Cannot set grad twice' - self.params_with_grad.add(param) - # If all params in bucket group have grads available, issue communication call. - if len(self.params_with_grad) == len(self.params): - self.start_grad_sync() - - -class _ParamAndGradBuffer: - """ - Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into - buckets with roughly `bucket_size` parameters each. - - Args: - ddp_config: DistributedDataParallel config object. - param_dtype: Type of param tensor. - grad_dtype: Type of grad tensor. - params: List of parameters whose parameters and gradients are collated in the underlying - tensor. - data_parallel_group: Data-parallel process group. - bucket_size: The rough size of each bucket in terms of number of parameters. - param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). - gradient_scaling_factor: This factor is utilized to scale gradients prior to their - communication. Its application is twofold: it facilitates the averaging of gradients - and the scaling of gradients in the context of the Mixture of Experts (MoE) model. - param_indices: The index of each param among the params with same dtype, if a param is fp8, - use its "fake" high precision dtype to determine which params have same dtype with it. - These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode. - """ - - def __init__( - self, - ddp_config: DistributedDataParallelConfig, - param_dtype: torch.dtype, - grad_dtype: torch.dtype, - params: List[torch.nn.Parameter], - data_parallel_group: torch.distributed.ProcessGroup, - bucket_size: int, - param_to_name: Dict[torch.nn.Parameter, str], - gradient_scaling_factor: float, - param_indices: List[int], - ): - self.ddp_config = ddp_config - self.params = params - self.param_indices = param_indices - - # Check that params are unique. - unique_params = set() - for param in params: - assert param not in unique_params - unique_params.add(param) - del unique_params - - # Store attributes that will be needed later. - self.param_dtype = param_dtype - self.grad_dtype = grad_dtype - self.data_parallel_group = data_parallel_group - self.data_parallel_world_size = torch.distributed.get_world_size( - group=self.data_parallel_group - ) - self.gradient_scaling_factor = gradient_scaling_factor - - # Data structures to store underlying buckets and relevant indexing data. - self.buckets = [] - self.param_to_bucket = {} # Param -> bucket mapping. 
- self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer). - - def _pad(number_to_be_padded: int, divisor: int) -> int: - return int(math.ceil(number_to_be_padded / divisor) * divisor) - - def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int: - """ - Pads end index of bucket if using distributed optimizer (to ensure uniform sharding). - """ - if self.ddp_config.use_distributed_optimizer: - # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. - # This also helps cuBLAS pick more efficient algorithms for GEMMs. - # We now ensure that all buckets start at a memory address that is 256-byte - # aligned (128 values since params and grads use >= 16-bit precision). - return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128)) - return bucket_end_index - - def _pad_start_of_param_if_needed(param_start_index: int) -> int: - """ - Pads start index of param if using distributed optimizer (to ensure "good" alignment). - """ - if self.ddp_config.use_distributed_optimizer: - # Ensure that params start at 128-byte aligned addresses (64 values - # since params are >= 16-bit precision). - return _pad(param_start_index, 64) - return param_start_index - - # First, figure out how many elements should be in the underlying buffer storage. - # Note that if we need to split the buffer into smaller buckets, each of these - # might need to be padded as well (if using the distributed optimizer). - param_start_index = 0 - bucket_start_index = param_start_index - bucket_params = set() - self.bucket_indices = [] - per_bucket_numel_unpadded = [] - bucket_id = 0 - - def _update_bucket_metadata(param_end_index: int) -> int: - """ - Record metadata for the bucket starting at bucket_start_index and ending with the - passed-in param_end_index. Returns the bucket's end_index. - """ - nonlocal bucket_start_index, bucket_params, bucket_id - per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) - bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) - - # Record metadata of new bucket. - self.bucket_indices.append((bucket_start_index, bucket_end_index)) - bucket_start_index = bucket_end_index - - # Prepare for next bucket. - bucket_params = set() - bucket_id += 1 - - # Return the potentially padded bucket_end_index. - return bucket_end_index - - def _does_param_require_new_bucket(param): - """ - Split shared embedding parameters into separate bucket if using distributed - optimizer that makes use of reduce-scatters instead of all-reduces. - This ensures that the first and last pipeline stage partition optimizer state - for the shared embedding parameters the same way across DP replicas, allowing - the DP reduce-scatter to be before the embedding all-reduce. - """ - return ( - getattr(param, "shared_embedding", False) - and self.ddp_config.use_distributed_optimizer - ) - - for param in params[::-1]: - # Iterate through parameters in reverse order to roughly follow backprop order. - - this_numel = param.data.nelement() - param_start_index = _pad_start_of_param_if_needed(param_start_index) - - # Create bucket with collected parameters if current param needs its own bucket. - if _does_param_require_new_bucket(param): - # We are creating a bucket for the already accumulated parameters, whose params - # end at the current param_start_index. - if self.ddp_config.use_distributed_optimizer: - # Make sure new bucket is appropriately padded. 
- if param_start_index % self.data_parallel_world_size != 0: - param_start_index = _pad_end_of_bucket_if_needed(param_start_index) - if len(bucket_params) > 0: - bucket_end_index = _update_bucket_metadata(param_start_index) - - param_end_index = param_start_index + this_numel - self.param_index_map[param] = (param_start_index, param_end_index, bucket_id) - bucket_params.add(param) - - # If we have enough elements already or the current param is part of the shared - # embedding layer and needs a separate bucket, form a new bucket. - if ( - bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size - ) or _does_param_require_new_bucket(param): - bucket_end_index = _update_bucket_metadata(param_end_index) - param_start_index = bucket_end_index - else: - param_start_index = param_end_index - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - bucket_end_index = _update_bucket_metadata(param_end_index) - - # Next, create underlying storage for buffer (with numel elements that includes - # padding as necessary). - self.numel = bucket_end_index - self.numel_unpadded = sum(per_bucket_numel_unpadded) - assert self.numel_unpadded <= self.numel - if self.ddp_config.use_distributed_optimizer: - assert self.numel % self.data_parallel_world_size == 0 - else: - assert self.numel == self.numel_unpadded - - self.param_data = None - # Only re-map param tensors if using distributed optimizer. - if self.ddp_config.use_distributed_optimizer: - self.param_data = torch.zeros( - self.numel, - dtype=self.param_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - self.grad_data = torch.zeros( - self.numel, - dtype=self.grad_dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - # Finally, map param.data and param.main_grad fields to buffers. - bucket_params = [] - bucket_start_index = 0 - cur_bucket_id = 0 - for param in params[::-1]: - param_start_index, param_end_index, bucket_id = self.param_index_map[param] - - # Assign param.data to appropriate segment of self.param_data. - if self.param_data is not None: - old_param_data = param.data - new_param_data = self._get( - param.data.shape, param_start_index, buffer_type=BufferType.PARAM - ) - if is_float8tensor(param): - param._data = new_param_data - else: - param.data = new_param_data - assert old_param_data._base is None - # Copy tensor values (from initialization or checkpoint). - param.data.detach().copy_(old_param_data) - del old_param_data - - param.main_grad = self._get( - param.data.shape, param_start_index, buffer_type=BufferType.GRAD - ) - if bucket_id != cur_bucket_id: - bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index) - self.buckets.append( - self._new_bucket( - bucket_params=bucket_params, - start_index=bucket_start_index, - end_index=bucket_end_index, - numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], - bucket_id=cur_bucket_id, - ) - ) - bucket_start_index = bucket_end_index - bucket_params = [] - assert cur_bucket_id + 1 == len(self.buckets) - assert bucket_id == cur_bucket_id + 1 - cur_bucket_id = bucket_id - bucket_params.append(param) - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) - self.buckets.append( - self._new_bucket( - bucket_params=bucket_params, - start_index=bucket_start_index, - end_index=bucket_end_index, - numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], - bucket_id=cur_bucket_id, - ) - ) - - # Log buckets for all PP stages. 
- log_strs = [] - log_strs.append( - f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}' - ) - for index, bucket in enumerate(self.buckets): - numel = 0 - for param in bucket.params: - numel += param.data.nelement() - log_strs.append(f'Params for bucket {index+1} ({numel} elements):') - for param in bucket.params: - log_strs.append(f'\t{param_to_name[param]}') - log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs)) - - def scale_gradients(self, scaling_factor: float) -> None: - """Scale the gradient data by `scaling_factor`.""" - self.grad_data *= scaling_factor - - def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor: - """ - Return a tensor with the input `shape` as a view into the 1-D data starting at - `start_index`. - """ - end_index = start_index + shape.numel() - assert end_index <= self.numel, 'Requested tensor is out of buffer range' - if buffer_type == BufferType.PARAM: - assert self.param_data is not None - buffer_tensor = self.param_data[start_index:end_index] - elif buffer_type == BufferType.GRAD: - buffer_tensor = self.grad_data[start_index:end_index] - else: - raise Exception("Illegal buffer type provided to GradBuffer._get() function") - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor - - def _new_bucket( - self, - bucket_params: List[torch.nn.Parameter], - start_index: int, - end_index: int, - numel_unpadded: int, - bucket_id: int, - ) -> _ParamAndGradBucket: - """ - Helper function that creates a new bucket. Also updates param->bucket mapping. - """ - - # Assert that indices are correctly padded (if needed), and that bucket - # position is same as originally computed. - if self.ddp_config.use_distributed_optimizer: - assert start_index % self.data_parallel_world_size == 0 - assert end_index % self.data_parallel_world_size == 0 - assert (start_index, end_index) == self.bucket_indices[bucket_id] - - # Get appropriate view into global _ParamAndGradBuffer. - bucketed_param_data = None - if self.param_data is not None: - bucketed_param_data = self._get( - torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM - ) - bucketed_grad_data = self._get( - torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD - ) - bucket = _ParamAndGradBucket( - params=bucket_params, - param_data=bucketed_param_data, - grad_data=bucketed_grad_data, - offset=start_index, - numel_unpadded=numel_unpadded, - gradient_scaling_factor=self.gradient_scaling_factor, - bucket_id=bucket_id, - ) - for bucket_param in bucket_params: - assert bucket_param not in self.param_to_bucket - self.param_to_bucket[bucket_param] = bucket - - return bucket - - def reset(self): - """ - Zero out the underlying grad_buffer. - """ - self.grad_data.zero_() - - -def partition_buckets( - buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False -) -> List[_ParamAndGradBucketGroup]: - """ - Automatically regroup the buckets of input buffers and return a list of bucket groups. - - In some scenarios, we need to put buckets from different buffers into a group so that their - communication can be aggregated. 
- - For example, when there are both fp8 weights and bf16 biases in the model and virtual - pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket, - which doubles the number of communication kernels, and because of the use of - CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the - overlap of communication kernels with computation kernels. - - The grouping strategy is: - 1. If force_single_bucket_group is True, put all buckets across all buffers into a single - bucket group. - 2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers, - let each bucket group have only one bucket. - 3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets - into the last fp8 bucket group. - - Since the non-fp8 parameters (typically the biases of various layers) are relatively - small, they are likely to be grouped into a single non-fp8 bucket. - - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to - the end of the model, while the last bucket corresponds to the beginning. - - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the - reduce-scatter to synchronize gradients after the backward pass at the end of the model - has completed. This is because we need to wait for the non-fp8 params from the beginning - layers to obtain their gradients. - - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue. - - Args: - buffers (list): list of input buffers. - single_bucket_group_per_buffer (bool, optional): force group all buckets in each buffer - into a single bucket group. - """ - - if len(buffers) == 0: - return [] - - dtype_to_buffer_map = {} - for buffer in buffers: - dtype = buffer.param_dtype - # Make sure that the param_dtype of any two buffers is different. - assert dtype not in dtype_to_buffer_map - dtype_to_buffer_map[dtype] = buffer - - # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True. - if force_single_bucket_group: - buckets = [] - ddp_config = buffers[0].ddp_config - data_parallel_group = buffers[0].data_parallel_group - data_parallel_world_size = buffers[0].data_parallel_world_size - for buffer in buffers: - assert ddp_config == buffer.ddp_config - assert data_parallel_group == buffer.data_parallel_group - assert data_parallel_world_size == buffer.data_parallel_world_size - buckets.extend(buffer.buckets) - - bucket_group = _ParamAndGradBucketGroup( - buckets, ddp_config, data_parallel_group, data_parallel_world_size - ) - return [bucket_group] - - if torch.uint8 not in dtype_to_buffer_map: - # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have - # only one bucket. - bucket_groups = [] - for buffer in buffers: - for bucket in buffer.buckets: - bucket_groups.append( - _ParamAndGradBucketGroup( - [bucket], - buffer.ddp_config, - buffer.data_parallel_group, - buffer.data_parallel_world_size, - ) - ) - return bucket_groups - else: - # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group. - non_fp8_buckets = [] - for buffer in buffers: - if buffer.param_dtype != torch.uint8: - for bucket in buffer.buckets: - non_fp8_buckets.append(bucket) - - bucket_groups = [] - fp8_buffer = dtype_to_buffer_map[torch.uint8] - for bucket in fp8_buffer.buckets: - if len(bucket_groups) == len(fp8_buffer.buckets) - 1: - # The last bucket group. 
- group_buckets = [bucket] + non_fp8_buckets - else: - # The first N-1 bucket groups. - group_buckets = [bucket] - bucket_groups.append( - _ParamAndGradBucketGroup( - group_buckets, - buffer.ddp_config, - buffer.data_parallel_group, - buffer.data_parallel_world_size, - ) - ) - return bucket_groups +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import logging +import math +import warnings +from contextlib import nullcontext +from enum import Enum +from functools import partial +from typing import Dict, List, Optional + +import torch +from torch.distributed import _coalescing_manager + +from megatron.core.rerun_state_machine import get_rerun_state_machine + +from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage +from .distributed_data_parallel_config import DistributedDataParallelConfig + +logger = logging.getLogger(__name__) + + +if is_torch_min_version("1.13.0"): + dist_all_gather_func = torch.distributed.all_gather_into_tensor + dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor +else: + dist_all_gather_func = torch.distributed._all_gather_base + dist_reduce_scatter_func = torch.distributed._reduce_scatter_base + + +class BufferType(Enum): + """ + Enumeration for buffer type. + """ + + PARAM = 1 + GRAD = 2 + + +def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): + """ + Shard buffer into data_parallel_world_size chunks of equal size. + """ + assert buffer.numel() % data_parallel_world_size == 0 + shard_size = buffer.numel() // data_parallel_world_size + sharded_buffer = [ + buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size) + ] + return sharded_buffer + + +class _ParamAndGradBucket: + """ + Bucket to keep track of a subset of the model's parameters and gradients. + + Args: + params: List of parameters whose gradients are collated in this bucket. + param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for. + grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for. + offset: Offset of this bucket's view in the larger _ParamAndGradBuffer. + numel_unpadded: Number of unpadded elements in bucket. + gradient_scaling_factor: This factor is utilized to scale gradients prior to their + communication. Its application is twofold: it facilitates the averaging of gradients + and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + bucket_id: Index of bucket in buffer. + """ + + def __init__( + self, + params: List[torch.nn.Parameter], + param_data: Optional[torch.Tensor], + grad_data: torch.Tensor, + offset: int, + numel_unpadded: int, + gradient_scaling_factor: float, + bucket_id: int, + ): + self.params_list = params + self.params = set(params) + # Make sure there are no duplicate params. + assert len(self.params_list) == len(self.params) + self.param_data = param_data + self.grad_data = grad_data + # The distributed optimizer needs to keep track of this bucket's offset + # within the full grad_buffer. + self.offset = offset + self.numel_unpadded = numel_unpadded + self.gradient_scaling_factor = gradient_scaling_factor + self.bucket_id = bucket_id + + +class _ParamAndGradBucketGroup: + """ + Put multiple buckets into a group so that their communications can be aggregated together. 
+    Provides functionality to register when params in the bucket group have grads ready to be
+    synced; an asynchronous communication call is automatically launched when _all_ params in
+    the bucket group have grads ready.
+
+    Args:
+        buckets: A list of buckets.
+        ddp_config: DistributedDataParallel config object.
+        collective_group: intra_distributed_optimizer_instance_group if using distributed
+            optimizer, data_parallel_group if not.
+        collective_group_size: World size of the collective group.
+    """
+
+    def __init__(
+        self,
+        buckets: List[_ParamAndGradBucket],
+        ddp_config: DistributedDataParallelConfig,
+        collective_group: torch.distributed.ProcessGroup,
+        collective_group_size: int,
+    ):
+        self.buckets = buckets
+        self.ddp_config = ddp_config
+
+        if self.ddp_config.use_distributed_optimizer:
+            self.intra_distributed_optimizer_instance_group = collective_group
+            self.intra_distributed_optimizer_instance_size = collective_group_size
+            self.intra_distributed_optimizer_instance_rank = torch.distributed.get_rank(
+                group=collective_group
+            )
+        else:
+            self.data_parallel_group = collective_group
+
+        # State for bookkeeping: params is the set of parameters this bucket group is
+        # responsible for, params_with_grad is the set of parameters with grads
+        # available. When overlap_grad_reduce is True, communication (all-reduce
+        # or reduce-scatter) is issued when params_with_grad equals params.
+        self.param_to_bucket = {}
+        self.params = set()
+        for bucket in self.buckets:
+            for param in bucket.params_list:
+                self.param_to_bucket[param] = bucket
+                self.params.add(param)
+
+        self.next_param_gather_bucket_group = None
+
+        if self.ddp_config.num_distributed_optimizer_instances > 1:
+            self.inter_distributed_optimizer_instance_group = None
+            self.communication_stream = None
+
+        self.reset()
+        self.param_gather_handle = None
+        self.param_gather_dispatched = False
+        self.grad_reduce_handle = None
+
+    def reset(self):
+        """
+        Reset metadata in bucket group in preparation for the next iteration of training.
+        """
+        self.params_with_grad = set()
+        self.is_last_microbatch = True
+
+    def check_grads(self, check_for_nan_or_inf, check_for_large):
+        """
+        Make sure the norm of the grads in each bucket is not NaN or Inf, and optionally not
+        unexpectedly large, prior to the data-parallel all-reduce / reduce-scatter.
+ """ + rerun_state_machine = get_rerun_state_machine() + for i in range(len(self.buckets)): + grad_norm = self.buckets[i].grad_data.norm(p=2) + # check for NaN, Inf and unexpectedly large grads + if check_for_nan_or_inf: + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=torch.isnan, + message=f"found NaN in local grad norm for bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=True, + ) + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=torch.isinf, + message=f"found Inf in local grad norm for bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=True, + ) + if check_for_large: + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, threshold=10, context="grads" + ), + message=f"found unexpected large grads in bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=False, + ) + + def start_param_sync(self, force_sync: bool = False): + """ + Initiates all necessary param all-gathers for this bucket. + + When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous + communication call (unless force_sync is True). When ddp_config.overlap_param_gather + is set to False, makes synchronous call. + + Args: + force_sync (bool, optional): force synchronous collective regardless of + other settings if true. + """ + assert self.ddp_config.use_distributed_optimizer + + if force_sync: + if self.param_gather_handle is not None: + self.param_gather_handle.wait() + self.param_gather_handle = None + return + else: + assert self.param_gather_handle is None + + async_op = self.ddp_config.overlap_param_gather and not force_sync + # Coalesce communication kernels across buckets in the bucket group. + with _coalescing_manager( + self.intra_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: + for bucket in self.buckets: + local_data_view = shard_buffer( + bucket.param_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + dist_all_gather_func( + bucket.param_data, + local_data_view, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, + ) + if async_op: + self.param_gather_handle = cm + else: + # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, + # `cm` is not None, which is different from when `_coalescing_manager` is not used in + # which case the torch.distributed._all_gather_base() will return None. In order to + # maintain consistency with prior code, we need to manually set communication handle to + # None. + self.param_gather_handle = None + self.param_gather_dispatched = True + + def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): + """ + Finishes param sync communication operation for this bucket. Dispatches + next bucket's param sync if available, unless skip_next_bucket_dispatch + is True. + + When ddp_config.overlap_param_gather is set to True, waits for asynchronous + communication call to complete (and dispatches one if one is not already + outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to + False. 
+
+        Args:
+            skip_next_bucket_dispatch (bool, optional): if true, skip dispatching
+                the next bucket's communication even if it is available.
+        """
+        assert self.ddp_config.use_distributed_optimizer
+        assert self.ddp_config.overlap_param_gather
+
+        # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
+        # AG bucket in first model chunk if ddp_config.align_param_gather is False).
+        if not self.param_gather_dispatched:
+            self.start_param_sync()
+
+        if self.param_gather_handle is not None:
+            self.param_gather_handle.wait()
+            self.param_gather_handle = None
+            # Dispatch next bucket's asynchronous param AG only if it has not been dispatched yet.
+            if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
+                if self.next_param_gather_bucket_group.param_gather_dispatched:
+                    warnings.warn(
+                        "The next bucket's parameter all-gather operation has already been "
+                        "dispatched. This may be caused by a mismatch between the order of "
+                        "parameter registration and forward pass execution, which will "
+                        "hurt the communication-computation overlap performance."
+                    )
+                else:
+                    self.next_param_gather_bucket_group.start_param_sync()
+
+    def start_grad_sync(self):
+        """
+        Initiates grad sync (all-reduce or reduce-scatter) communication operations
+        for all buckets in the bucket group.
+
+        When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
+        communication call. When ddp_config.overlap_grad_reduce is set to False, makes
+        synchronous call.
+        """
+        assert (
+            self.grad_reduce_handle is None
+        ), 'Should not have multiple communication calls outstanding at once'
+
+        if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
+            self.check_grads(
+                check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
+                check_for_large=self.ddp_config.check_for_large_grads,
+            )
+
+        # gradient_scaling_factor already takes into account whether we are computing
+        # an average or sum in the data-parallel collective.
+        for bucket in self.buckets:
+            if bucket.gradient_scaling_factor != 1.0:
+                bucket.grad_data *= bucket.gradient_scaling_factor
+
+        # Decide reduce_op.
+        reduce_op = torch.distributed.ReduceOp.SUM
+        if self.ddp_config.average_in_collective:
+            reduce_op = torch.distributed.ReduceOp.AVG
+
+        # We use the following stream synchronization for the gradient reduction
+        # within and across DistOpt instances.
+
+        # Compute Stream: -------------Gradient compute-------------------
+        # Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
+        # NCCL Stream: -------RS------ -------AR------
+
+        # Use async communications only when overlap_grad_reduce is True.
+        async_op = (
+            self.ddp_config.overlap_grad_reduce
+            and self.ddp_config.num_distributed_optimizer_instances == 1
+        )
+        if (
+            self.ddp_config.num_distributed_optimizer_instances > 1
+            and self.ddp_config.overlap_grad_reduce
+        ):
+            # Assign a communication stream if we have multiple DistOpt instances and we
+            # need to overlap communication.
+            stream_context = torch.cuda.stream(self.communication_stream)
+
+            # The RS/AR communication stream needs to wait for the default stream
+            # to complete its gradient computation before launching the next
+            # gradient reduction collective.
+ self.communication_stream.wait_stream(torch.cuda.default_stream()) + else: + stream_context = nullcontext() + + if self.ddp_config.use_distributed_optimizer: + communication_group = self.intra_distributed_optimizer_instance_group + else: + communication_group = self.data_parallel_group + + # Coalesce communication kernels across buckets in the bucket group. + with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm: + for bucket in self.buckets: + if self.ddp_config.use_distributed_optimizer: + local_data_view = shard_buffer( + bucket.grad_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + dist_reduce_scatter_func( + local_data_view, + bucket.grad_data, + op=reduce_op, + group=communication_group, + async_op=async_op, + ) + else: + torch.distributed.all_reduce( + bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op + ) + + # With multiple DistOpt instances, we need to all-reduce across instances. + if ( + self.ddp_config.use_distributed_optimizer + and self.ddp_config.num_distributed_optimizer_instances > 1 + ): + + # Create a new coalescing manager for the inter-instance all-reduce. + with stream_context, _coalescing_manager( + self.inter_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: + for bucket in self.buckets: + local_data_view = shard_buffer( + bucket.grad_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + + torch.distributed.all_reduce( + local_data_view, + op=reduce_op, + group=self.inter_distributed_optimizer_instance_group, + async_op=async_op, + ) + + if async_op: + self.grad_reduce_handle = cm + else: + # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, + # `cm` is not None, which is different from when `_coalescing_manager` is not used in + # which case the torch.distributed._reduce_scatter_base() will return None. In order to + # maintain consistency with prior code, we need to manually set communication handle to + # None. + self.grad_reduce_handle = None + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all buckets in the bucket group. + + When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous + communication call to complete. When ddp_config.overlap_grad_reduce is set to False, + makes synchronous call. + """ + self.param_gather_dispatched = False + # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. + if not self.ddp_config.overlap_grad_reduce: + self.start_grad_sync() + return + # When using multiple DistOpt instances, we don't need to sync here as we launch + # communications on a separate communication stream. + if self.ddp_config.num_distributed_optimizer_instances > 1: + torch.cuda.default_stream().wait_stream(self.communication_stream) + return + assert self.grad_reduce_handle is not None, ( + f'Communication call has not been issued for this bucket ' + f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)' + ) + self.grad_reduce_handle.wait() + self.grad_reduce_handle = None + + def register_grad_ready(self, param: torch.nn.Parameter): + """ + Registers grads for the passed-in param to be "ready" for grad sync. + + When the number of microbatches is greater than 1, we only want to register + grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce + is True. 
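+
+        (For orientation: this is normally invoked from the per-parameter backward
+        hook installed by DistributedDataParallel, so once the final param in the
+        bucket group receives its grad on the last microbatch, start_grad_sync() is
+        launched automatically.)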
+ """ + assert ( + self.ddp_config.overlap_grad_reduce + ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' + if self.is_last_microbatch: + assert param in self.param_to_bucket, 'Param is not in the bucket group' + assert param not in self.params_with_grad, 'Cannot set grad twice' + self.params_with_grad.add(param) + # If all params in bucket group have grads available, issue communication call. + if len(self.params_with_grad) == len(self.params): + self.start_grad_sync() + + +class _ParamAndGradBuffer: + """ + Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into + buckets with roughly `bucket_size` parameters each. + + Args: + ddp_config: DistributedDataParallel config object. + param_dtype: Type of param tensor. + grad_dtype: Type of grad tensor. + params: List of parameters whose parameters and gradients are collated in the underlying + tensor. + data_parallel_group: Data-parallel process group. + bucket_size: The rough size of each bucket in terms of number of parameters. + param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). + gradient_scaling_factor: This factor is utilized to scale gradients prior to their + communication. Its application is twofold: it facilitates the averaging of gradients + and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + param_indices: The index of each param among the params with same dtype, if a param is fp8, + use its "fake" high precision dtype to determine which params have same dtype with it. + These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode. + """ + + def __init__( + self, + ddp_config: DistributedDataParallelConfig, + param_dtype: torch.dtype, + grad_dtype: torch.dtype, + params: List[torch.nn.Parameter], + data_parallel_group: torch.distributed.ProcessGroup, + bucket_size: int, + param_to_name: Dict[torch.nn.Parameter, str], + gradient_scaling_factor: float, + param_indices: List[int], + ): + self.ddp_config = ddp_config + self.params = params + self.param_indices = param_indices + + # Check that params are unique. + unique_params = set() + for param in params: + assert param not in unique_params + unique_params.add(param) + del unique_params + + # Store attributes that will be needed later. + self.param_dtype = param_dtype + self.grad_dtype = grad_dtype + self.data_parallel_group = data_parallel_group + self.data_parallel_world_size = torch.distributed.get_world_size( + group=self.data_parallel_group + ) + self.gradient_scaling_factor = gradient_scaling_factor + + # Data structures to store underlying buckets and relevant indexing data. + self.buckets = [] + self.param_to_bucket = {} # Param -> bucket mapping. + self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer). + + def _pad(number_to_be_padded: int, divisor: int) -> int: + return int(math.ceil(number_to_be_padded / divisor) * divisor) + + def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int: + """ + Pads end index of bucket if using distributed optimizer (to ensure uniform sharding). + """ + if self.ddp_config.use_distributed_optimizer: + # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. + # This also helps cuBLAS pick more efficient algorithms for GEMMs. + # We now ensure that all buckets start at a memory address that is 256-byte + # aligned (128 values since params and grads use >= 16-bit precision). 
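+                # Worked example with illustrative numbers: for data_parallel_world_size=64
+                # and no busbw padding, the divisor is lcm(64, 128) = 128, so a bucket
+                # ending at element 1_000_003 is padded up to 1_000_064, which preserves
+                # the 256-byte alignment and gives every rank an equal 15_626-element shard.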
+ if self.ddp_config.pad_buckets_for_high_nccl_busbw: + # Make sure the bucket size is divisible by a large power of 2 (2^16) to + # ensure NCCL collectives have high bus bandwidth at large DP counts, + # since NCCL message size (which for ring algorithms is bucket_size / + # dp_size) apparently needs to be divisible by a power of 2 for high busbw. + bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16) + else: + bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128) + return _pad(bucket_end_index, bucket_size_divisor) + return bucket_end_index + + def _pad_start_of_param_if_needed(param_start_index: int) -> int: + """ + Pads start index of param if using distributed optimizer (to ensure "good" alignment). + """ + if self.ddp_config.use_distributed_optimizer: + # Ensure that params start at 128-byte aligned addresses (64 values + # since params are >= 16-bit precision). + return _pad(param_start_index, 64) + return param_start_index + + # First, figure out how many elements should be in the underlying buffer storage. + # Note that if we need to split the buffer into smaller buckets, each of these + # might need to be padded as well (if using the distributed optimizer). + param_start_index = 0 + bucket_start_index = param_start_index + bucket_params = set() + self.bucket_indices = [] + per_bucket_numel_unpadded = [] + bucket_id = 0 + + def _update_bucket_metadata(param_end_index: int) -> int: + """ + Record metadata for the bucket starting at bucket_start_index and ending with the + passed-in param_end_index. Returns the bucket's end_index. + """ + nonlocal bucket_start_index, bucket_params, bucket_id + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) + + # Record metadata of new bucket. + self.bucket_indices.append((bucket_start_index, bucket_end_index)) + bucket_start_index = bucket_end_index + + # Prepare for next bucket. + bucket_params = set() + bucket_id += 1 + + # Return the potentially padded bucket_end_index. + return bucket_end_index + + def _does_param_require_new_bucket(param): + """ + Split shared embedding parameters into separate bucket if using distributed + optimizer that makes use of reduce-scatters instead of all-reduces. + This ensures that the first and last pipeline stage partition optimizer state + for the shared embedding parameters the same way across DP replicas, allowing + the DP reduce-scatter to be before the embedding all-reduce. + """ + return ( + getattr(param, "shared_embedding", False) + and self.ddp_config.use_distributed_optimizer + ) + + for param in params[::-1]: + # Iterate through parameters in reverse order to roughly follow backprop order. + + this_numel = param.data.nelement() + param_start_index = _pad_start_of_param_if_needed(param_start_index) + + # Create bucket with collected parameters if current param needs its own bucket. + if _does_param_require_new_bucket(param): + # We are creating a bucket for the already accumulated parameters, whose params + # end at the current param_start_index. + if self.ddp_config.use_distributed_optimizer: + # Make sure new bucket is appropriately padded. 
+ if param_start_index % self.data_parallel_world_size != 0: + param_start_index = _pad_end_of_bucket_if_needed(param_start_index) + if len(bucket_params) > 0: + bucket_end_index = _update_bucket_metadata(param_start_index) + + param_end_index = param_start_index + this_numel + self.param_index_map[param] = (param_start_index, param_end_index, bucket_id) + bucket_params.add(param) + + # If we have enough elements already or the current param is part of the shared + # embedding layer and needs a separate bucket, form a new bucket. + if ( + bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size + ) or _does_param_require_new_bucket(param): + bucket_end_index = _update_bucket_metadata(param_end_index) + param_start_index = bucket_end_index + else: + param_start_index = param_end_index + + # Add remaining params to a new bucket. + if len(bucket_params) > 0: + bucket_end_index = _update_bucket_metadata(param_end_index) + + # Next, create underlying storage for buffer (with numel elements that includes + # padding as necessary). + self.numel = bucket_end_index + self.numel_unpadded = sum(per_bucket_numel_unpadded) + assert self.numel_unpadded <= self.numel + if self.ddp_config.use_distributed_optimizer: + assert self.numel % self.data_parallel_world_size == 0 + else: + assert self.numel == self.numel_unpadded + + self.param_data = None + # Only re-map param tensors if using distributed optimizer. + if self.ddp_config.use_distributed_optimizer: + self.param_data = torch.zeros( + self.numel, + dtype=self.param_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + self.grad_data = torch.zeros( + self.numel, + dtype=self.grad_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + # Finally, map param.data and param.main_grad fields to buffers. + bucket_params = [] + bucket_start_index = 0 + cur_bucket_id = 0 + for param in params[::-1]: + param_start_index, param_end_index, bucket_id = self.param_index_map[param] + + # Assign param.data to appropriate segment of self.param_data. + if self.param_data is not None: + old_param_data = param.data + new_param_data = self._get( + param.data.shape, param_start_index, buffer_type=BufferType.PARAM + ) + if is_float8tensor(param): + param._data = new_param_data + else: + param.data = new_param_data + assert old_param_data._base is None + # Copy tensor values (from initialization or checkpoint). + param.data.detach().copy_(old_param_data) + del old_param_data + + param.main_grad = self._get( + param.data.shape, param_start_index, buffer_type=BufferType.GRAD + ) + if bucket_id != cur_bucket_id: + bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index) + self.buckets.append( + self._new_bucket( + bucket_params=bucket_params, + start_index=bucket_start_index, + end_index=bucket_end_index, + numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], + bucket_id=cur_bucket_id, + ) + ) + bucket_start_index = bucket_end_index + bucket_params = [] + assert cur_bucket_id + 1 == len(self.buckets) + assert bucket_id == cur_bucket_id + 1 + cur_bucket_id = bucket_id + bucket_params.append(param) + + # Add remaining params to a new bucket. + if len(bucket_params) > 0: + bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) + self.buckets.append( + self._new_bucket( + bucket_params=bucket_params, + start_index=bucket_start_index, + end_index=bucket_end_index, + numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], + bucket_id=cur_bucket_id, + ) + ) + + # Log buckets for all PP stages. 
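+        # A typical entry produced below looks like (numbers and names illustrative):
+        #     Number of buckets for gradient all-reduce / reduce-scatter: 2
+        #     Params for bucket 1 (125829120 elements, 125960192 padded size):
+        #         module.decoder.layers.0.mlp.linear_fc1.weight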
+ log_strs = [] + log_strs.append( + f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}' + ) + for index, bucket in enumerate(self.buckets): + numel = 0 + for param in bucket.params: + numel += param.data.nelement() + log_strs.append( + f"Params for bucket {index+1} ({numel} elements, " + f"{bucket.grad_data.nelement()} padded size):" + ) + for param in bucket.params: + log_strs.append(f'\t{param_to_name[param]}') + log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs)) + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale the gradient data by `scaling_factor`.""" + self.grad_data *= scaling_factor + + def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor: + """ + Return a tensor with the input `shape` as a view into the 1-D data starting at + `start_index`. + """ + end_index = start_index + shape.numel() + assert end_index <= self.numel, 'Requested tensor is out of buffer range' + if buffer_type == BufferType.PARAM: + assert self.param_data is not None + buffer_tensor = self.param_data[start_index:end_index] + elif buffer_type == BufferType.GRAD: + buffer_tensor = self.grad_data[start_index:end_index] + else: + raise Exception("Illegal buffer type provided to GradBuffer._get() function") + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + def _new_bucket( + self, + bucket_params: List[torch.nn.Parameter], + start_index: int, + end_index: int, + numel_unpadded: int, + bucket_id: int, + ) -> _ParamAndGradBucket: + """ + Helper function that creates a new bucket. Also updates param->bucket mapping. + """ + + # Assert that indices are correctly padded (if needed), and that bucket + # position is same as originally computed. + if self.ddp_config.use_distributed_optimizer: + assert start_index % self.data_parallel_world_size == 0 + assert end_index % self.data_parallel_world_size == 0 + assert (start_index, end_index) == self.bucket_indices[bucket_id] + + # Get appropriate view into global _ParamAndGradBuffer. + bucketed_param_data = None + if self.param_data is not None: + bucketed_param_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM + ) + bucketed_grad_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD + ) + bucket = _ParamAndGradBucket( + params=bucket_params, + param_data=bucketed_param_data, + grad_data=bucketed_grad_data, + offset=start_index, + numel_unpadded=numel_unpadded, + gradient_scaling_factor=self.gradient_scaling_factor, + bucket_id=bucket_id, + ) + for bucket_param in bucket_params: + assert bucket_param not in self.param_to_bucket + self.param_to_bucket[bucket_param] = bucket + + return bucket + + def reset(self): + """ + Zero out the underlying grad_buffer. + """ + self.grad_data.zero_() + + +def partition_buckets( + buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False +) -> List[_ParamAndGradBucketGroup]: + """ + Automatically regroup the buckets of input buffers and return a list of bucket groups. + + In some scenarios, we need to put buckets from different buffers into a group so that their + communication can be aggregated. 
+ + For example, when there are both fp8 weights and bf16 biases in the model and virtual + pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket, + which doubles the number of communication kernels, and because of the use of + CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the + overlap of communication kernels with computation kernels. + + The grouping strategy is: + 1. If force_single_bucket_group is True, put all buckets across all buffers into a single + bucket group. + 2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers, + let each bucket group have only one bucket. + 3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets + into the last fp8 bucket group. + - Since the non-fp8 parameters (typically the biases of various layers) are relatively + small, they are likely to be grouped into a single non-fp8 bucket. + - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to + the end of the model, while the last bucket corresponds to the beginning. + - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the + reduce-scatter to synchronize gradients after the backward pass at the end of the model + has completed. This is because we need to wait for the non-fp8 params from the beginning + layers to obtain their gradients. + - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue. + + Args: + buffers (list): list of input buffers. + single_bucket_group_per_buffer (bool, optional): force group all buckets in each buffer + into a single bucket group. + """ + + if len(buffers) == 0: + return [] + + dtype_to_buffer_map = {} + for buffer in buffers: + dtype = buffer.param_dtype + # Make sure that the param_dtype of any two buffers is different. + assert dtype not in dtype_to_buffer_map + dtype_to_buffer_map[dtype] = buffer + + # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True. + if force_single_bucket_group: + buckets = [] + ddp_config = buffers[0].ddp_config + data_parallel_group = buffers[0].data_parallel_group + data_parallel_world_size = buffers[0].data_parallel_world_size + for buffer in buffers: + assert ddp_config == buffer.ddp_config + assert data_parallel_group == buffer.data_parallel_group + assert data_parallel_world_size == buffer.data_parallel_world_size + buckets.extend(buffer.buckets) + + bucket_group = _ParamAndGradBucketGroup( + buckets, ddp_config, data_parallel_group, data_parallel_world_size + ) + return [bucket_group] + + if torch.uint8 not in dtype_to_buffer_map: + # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have + # only one bucket. + bucket_groups = [] + for buffer in buffers: + for bucket in buffer.buckets: + bucket_groups.append( + _ParamAndGradBucketGroup( + [bucket], + buffer.ddp_config, + buffer.data_parallel_group, + buffer.data_parallel_world_size, + ) + ) + return bucket_groups + else: + # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group. + non_fp8_buckets = [] + for buffer in buffers: + if buffer.param_dtype != torch.uint8: + for bucket in buffer.buckets: + non_fp8_buckets.append(bucket) + + bucket_groups = [] + fp8_buffer = dtype_to_buffer_map[torch.uint8] + for bucket in fp8_buffer.buckets: + if len(bucket_groups) == len(fp8_buffer.buckets) - 1: + # The last bucket group. 
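+                # Illustrative example: with fp8 buckets [B0, B1, B2] (B0 covers the last
+                # layers, B2 the first) plus one small bf16 bias bucket C, the resulting
+                # groups are [B0], [B1] and [B2, C], so the reduce-scatter for the final
+                # layers can still launch as soon as their backward pass completes.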
+ group_buckets = [bucket] + non_fp8_buckets + else: + # The first N-1 bucket groups. + group_buckets = [bucket] + bucket_groups.append( + _ParamAndGradBucketGroup( + group_buckets, + buffer.ddp_config, + buffer.data_parallel_group, + buffer.data_parallel_world_size, + ) + ) + return bucket_groups diff --git a/megatron/core/distributed/torch_fully_sharded_data_parallel.py b/megatron/core/distributed/torch_fully_sharded_data_parallel.py index 6d2e84e77b8656d19466a0bd80860bdfad47a851..40a840e35ccdbee58c056f4107d9a968785c5830 100644 --- a/megatron/core/distributed/torch_fully_sharded_data_parallel.py +++ b/megatron/core/distributed/torch_fully_sharded_data_parallel.py @@ -1,115 +1,123 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from typing import List - -import torch - -try: - from torch.distributed import DeviceMesh - from torch.distributed._composable.fsdp import fully_shard - - HAVE_FSDP = True -except ImportError: - HAVE_FSDP = False - -from .. import parallel_state, tensor_parallel -from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from ..transformer.transformer_config import TransformerConfig -from ..transformer.transformer_layer import TransformerLayer -from .data_parallel_base import _BaseDataParallel - - -class TorchFullyShardedDataParallel(_BaseDataParallel): - """ - Enables fully sharded data parallelism by wrapping the given model with - the PyTorch FSDP2 API: - https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md - To utilize this class, PyTorch version >= 2.4.0 is required. - - Args: - config: Transformer config object. - module: Underlying model. - sub_modules_to_wrap: List of sub_modules to shard with FSDP. - Parameters within each sub_module will be all-gathered just-in-time. - The default list includes the following submodules derived from the - GPT model architecture: - TransformerLayer (all Transformer layers) - LanguageModelEmbedding (initial embedding layer) - RotaryEmbedding (initial RoPE layer) - tensor_parallel.ColumnParallelLinear (final output layer) - """ - - def __init__( - self, - config: TransformerConfig, - module: torch.nn.Module, - sub_modules_to_wrap: List[torch.nn.Module] = [ - TransformerLayer, - LanguageModelEmbedding, - RotaryEmbedding, - tensor_parallel.ColumnParallelLinear, - ], - **kwargs - ): - - assert ( - HAVE_FSDP - ), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.' - - super().__init__(config=config, module=module) - self.data_parallel_group = parallel_state.get_data_parallel_group( - with_context_parallel=True - ) - - mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda") - - kwargs = {"mesh": mesh} - - def save_custom_attrs(module): - custom_attrs = {} - for name, param in module.named_parameters(): - attrs = vars(param) - custom_attrs[name] = {k: v for k, v in attrs.items()} - return custom_attrs - - def restore_custom_attrs(module, custom_attrs): - for name, param in module.named_parameters(): - if name in custom_attrs: - for attr_name, attr_value in custom_attrs[name].items(): - setattr(param, attr_name, attr_value) - - # Save the custom attributes on Parameters before FSDP overwrites them. - # See https://github.com/pytorch/pytorch/issues/136929. 
- attrs = save_custom_attrs(self.module) - - prev_module = None - for sub_module in self.module.modules(): - # Wrap individual submodules to fetch parameters just-in-time rather than - # conservatively fetching all parameters at the start of each iteration. - # See https://github.com/pytorch/pytorch/issues/114299. - if any( - isinstance(sub_module, sub_module_to_wrap) - for sub_module_to_wrap in sub_modules_to_wrap - ): - fully_shard(sub_module, **kwargs) - - # Explicitly set the FSDP backward prefetch schedule to prevent activation - # recomputation from disrupting the automatically generated default schedule. - if config.recompute_granularity is not None: - sub_module.set_modules_to_backward_prefetch( - [prev_module] if prev_module else [] - ) - prev_module = sub_module - - # Wrap the root module as required by the FSDP API. - # See https://github.com/pytorch/pytorch/issues/114299. - fully_shard(self.module, **kwargs) - - restore_custom_attrs(self.module, attrs) - - def load_state_dict(self, state_dict, strict=True): - """ - No-op because tensors are already loaded in-place by - `_load_base_checkpoint` with FSDP2.""" - pass +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import List + +import torch + +try: + from torch.distributed import DeviceMesh + from torch.distributed._composable.fsdp import fully_shard + + HAVE_FSDP = True +except ImportError: + HAVE_FSDP = False + +from megatron.core.utils import is_float8tensor + +from .. import parallel_state, tensor_parallel +from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from ..transformer.transformer_config import TransformerConfig +from ..transformer.transformer_layer import TransformerLayer +from .data_parallel_base import _BaseDataParallel +from .distributed_data_parallel_config import DistributedDataParallelConfig + + +class TorchFullyShardedDataParallel(_BaseDataParallel): + """ + Enables fully sharded data parallelism by wrapping the given model with + the PyTorch FSDP2 API: + https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md + To utilize this class, PyTorch version >= 2.4.0 is required. + + Args: + config: Transformer config object. + ddp_config: DistributedDataParallel config object. + module: Underlying model. + sub_modules_to_wrap: List of sub_modules to shard with FSDP. + Parameters within each sub_module will be all-gathered just-in-time. + The default list includes the following submodules derived from the + GPT model architecture: + TransformerLayer (all Transformer layers) + LanguageModelEmbedding (initial embedding layer) + RotaryEmbedding (initial RoPE layer) + tensor_parallel.ColumnParallelLinear (final output layer) + """ + + def __init__( + self, + config: TransformerConfig, + ddp_config: DistributedDataParallelConfig, + module: torch.nn.Module, + sub_modules_to_wrap: List[torch.nn.Module] = [ + TransformerLayer, + LanguageModelEmbedding, + RotaryEmbedding, + tensor_parallel.ColumnParallelLinear, + ], + ): + + assert ( + HAVE_FSDP + ), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.' 
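+
+        # Illustrative usage sketch (the GPT model object here is hypothetical):
+        #     model = TorchFullyShardedDataParallel(config, ddp_config, gpt_model)
+        # After construction, parameters of each wrapped submodule are sharded across
+        # the data-parallel (with context-parallel) mesh built below and are
+        # all-gathered just in time during forward/backward.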
+ + super().__init__(config=config, module=module) + self.data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True + ) + + kwargs = {"mesh": DeviceMesh.from_group(self.data_parallel_group, "cuda")} + + def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + if is_float8tensor(param): + # disable fp8 transpose cache and perform transposing fp8 weights + # at each micro-batch because torch-FSDP doesn't recognize the + # micro-batch id, thus removing unnecessary memory stores + attrs['_fp8_attrs']['transpose_invalid'] = False + del attrs['_fp8_attrs']['transpose'] + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + # Save the custom attributes on Parameters before FSDP overwrites them. + # See https://github.com/pytorch/pytorch/issues/136929. + attrs = save_custom_attrs(self.module) + + prev_module = None + for sub_module in self.module.modules(): + # Wrap individual submodules to fetch parameters just-in-time rather than + # conservatively fetching all parameters at the start of each iteration. + # See https://github.com/pytorch/pytorch/issues/114299. + if any( + isinstance(sub_module, sub_module_to_wrap) + for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, **kwargs) + + # Explicitly set the FSDP backward prefetch schedule to prevent activation + # recomputation from disrupting the automatically generated default schedule. + if config.recompute_granularity is not None: + sub_module.set_modules_to_backward_prefetch( + [prev_module] if prev_module else [] + ) + prev_module = sub_module + + # Wrap the root module as required by the FSDP API. + # See https://github.com/pytorch/pytorch/issues/114299. 
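+        # (Roughly speaking, the root-level fully_shard call takes ownership of any
+        # parameters not claimed by the wrapped submodules above, so the whole model
+        # behaves as one composable FSDP unit; explanatory note only, not new behavior.)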
+ fully_shard(self.module, **kwargs) + + restore_custom_attrs(self.module, attrs) + + def load_state_dict(self, state_dict, strict=True): + """ + No-op because tensors are already loaded in-place by + `_load_base_checkpoint` with FSDP2.""" + pass diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py index 7a1401fb2416ece694b1d33dd6d61a5027bedb07..d3cd7ff296f1d9807b2bc94ac2e45fdcdbab9600 100644 --- a/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py @@ -21,6 +21,10 @@ DEFAULT_CONVERSION_DICT = { 'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias, 'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, 'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias, + # EXPERTS + 'decoder.layers.mlp.experts.experts.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight_mixture_of_experts, + 'decoder.layers.mlp.experts.experts.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight_mixture_of_experts, + 'decoder.layers.mlp.router.weight': TRTLLMLayers.mlp_router_weight, # FINAL LAYER NORM 'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight, 'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias, diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index a89e272e51ef9a12b1421c3054f40f95f3c53d15..29914b89373bfea06517547e607d4936e836f901 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1,1273 +1,1359 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
- -import dataclasses -import io -import os -import pickle -import warnings -from typing import Callable - -import torch -import transformer_engine as te -from packaging.version import Version as PkgVersion -from torch import Tensor -from torch.nn.parameter import Parameter - -from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.parallel_state import ( - get_context_parallel_global_ranks, - get_context_parallel_group, - get_expert_data_parallel_rank, - get_expert_model_parallel_rank, - get_expert_model_parallel_world_size, - get_expert_tensor_parallel_group, - get_expert_tensor_parallel_rank, - get_expert_tensor_parallel_world_size, - get_hierarchical_context_parallel_groups, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name -from megatron.core.tensor_parallel.layers import ( - _initialize_affine_weight_cpu, - set_tensor_model_parallel_attributes, -) -from megatron.core.tensor_parallel.utils import divide -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint -from megatron.core.utils import get_te_version, is_te_min_version - - -def _get_extra_te_kwargs(config: TransformerConfig): - extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} - - if is_te_min_version("0.12.0"): - if config.use_cpu_initialization: - extra_transformer_engine_kwargs["device"] = 'cpu' - else: - extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() - return extra_transformer_engine_kwargs - - -def condition_init_method(config, init_method): - """Condition TE init_method on config.perform_initialization.""" - return init_method if config.perform_initialization else (lambda w: None) - - -class TENorm: - """ - A conditional wrapper to initialize an instance of Transformer-Engine's - `LayerNorm` or `RMSNorm` based on input - """ - - # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? - def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): - if config.normalization == "LayerNorm": - instance = te.pytorch.LayerNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - elif config.normalization == "RMSNorm": - assert hasattr( - te.pytorch, "RMSNorm" - ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - else: - raise Exception('Only LayerNorm and RMSNorm are curently supported') - - return instance - - -class TELinear(te.pytorch.Linear): - """ - Wrapper for the Transformer-Engine's `Linear` layer. - - Note that if Megatron's parallel_state has not been initialized - yet, the tp_group passed to TE will be None and must be set later - via set_tensor_parallel_group(). 
- """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - parallel_mode: str, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - skip_bias_add: bool, - skip_weight_param_allocation: bool, - tp_comm_buffer_name: str = None, - is_expert: bool = False, - ): - self.config = config - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. - self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache - if skip_weight_param_allocation: - raise ValueError( - 'Transformer Engine linear layers do not support skip_weight_param_allocation' - ) - - extra_kwargs = _get_extra_te_kwargs(config) - - if is_te_min_version("0.8.0"): - if self.config.tp_comm_overlap: - if is_te_min_version("1.5.0"): - # Use old overlap flags if they were supplied instead - extra_kwargs["ub_overlap_ag"] = ( - self.config.tp_comm_overlap_ag - if hasattr(self.config, "tp_comm_overlap_ag") - else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag - ) - extra_kwargs["ub_overlap_rs"] = ( - self.config.tp_comm_overlap_rs - if hasattr(self.config, "tp_comm_overlap_rs") - else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs - ) - # Disable ub overlap for experts. - if is_expert: - extra_kwargs["ub_overlap_ag"] = False - extra_kwargs["ub_overlap_rs"] = False - else: - extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag - extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag - extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs - extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs - # Disable ub overlap for experts. - if is_expert: - extra_kwargs["ub_split_ag"] = False - extra_kwargs["ub_atomic_gemm_ag"] = False - extra_kwargs["ub_split_rs"] = False - extra_kwargs["ub_atomic_gemm_rs"] = False - if is_te_min_version("1.0.0", check_equality=False): - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - self.expert_parallel = self.config.expert_model_parallel_size > 1 - if is_expert: - rng_tracker_name = get_expert_parallel_rng_tracker_name() - else: - rng_tracker_name = None - if is_te_min_version("1.7.0"): - extra_kwargs["rng_tracker_name"] = rng_tracker_name - - # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. 
- if is_expert: - tp_group = get_expert_tensor_parallel_group(check_initialized=False) - tp_size = get_expert_tensor_parallel_world_size() - else: - tp_group = get_tensor_model_parallel_group(check_initialized=False) - tp_size = get_tensor_model_parallel_world_size() - explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) - - if explicit_expert_comm: - if parallel_mode == "column": - output_size = divide(output_size, tp_size) - elif parallel_mode == "row": - input_size = divide(input_size, tp_size) - parallel_mode = None - tp_size = 1 - tp_group = None - - super().__init__( - in_features=input_size, - out_features=output_size, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=tp_group, - tp_size=tp_size, - get_rng_state_tracker=( - get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None - ), - init_method=condition_init_method(config, init_method), - bias=bias, - return_bias=self.te_return_bias, - parallel_mode=parallel_mode, - **extra_kwargs, - ) - - for param in self.parameters(): - setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) - - def forward(self, x): - """Forward.""" - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) - out = super().forward(x, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - - # TE only returns a tuple when return_bias is True, otherwise - # it returns a single Tensor, we always want to return two - # values regardless of the arguments. - if self.te_return_bias: - return out - return out, None - - -class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): - """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines - layernorm and linear layers - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: TransformerConfig, - init_method: Callable, - gather_output: bool, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, - ): - self.config = config - - if gather_output: - raise ValueError('Transformer Engine linear layers do not support gather_output = True') - - if is_expert: - raise ValueError('Transformer Engine linear layers do not yet support MoE') - - if skip_weight_param_allocation: - raise ValueError( - 'Transformer Engine linear layers do not support skip_weight_param_allocation' - ) - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. - self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache - extra_kwargs = _get_extra_te_kwargs(config) - - # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` - if is_te_min_version("0.11.0"): - extra_kwargs["normalization"] = self.config.normalization - elif self.config.normalization != "LayerNorm": - te_version = get_te_version() - raise ValueError( - f"Transformer Engine v{te_version} does not support {self.config.normalization}." 
- ) - - if is_te_min_version("0.8.0"): - if self.config.tp_comm_overlap: - extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad - extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad - if is_te_min_version("1.5.0", check_equality=False): - # Use old overlap flags if they were supplied instead - extra_kwargs["ub_overlap_ag"] = ( - self.config.tp_comm_overlap_ag - if hasattr(self.config, "tp_comm_overlap_ag") - else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag - ) - if is_te_min_version("1.6.0.dev0", check_equality=False): - extra_kwargs["ub_overlap_rs_dgrad"] = ( - self.config.tp_comm_overlap_rs_dgrad - if hasattr(self.config, "tp_comm_overlap_rs_dgrad") - else False - ) - if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: - extra_kwargs["ub_overlap_ag"] = False - extra_kwargs["ub_overlap_rs_dgrad"] = False - - if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: - extra_kwargs["ub_overlap_ag"] = False - extra_kwargs["ub_overlap_rs_dgrad"] = False - else: - extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag - extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag - if is_te_min_version("1.0.0", check_equality=False): - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - super().__init__( - in_features=input_size, - out_features=output_size, - eps=self.config.layernorm_epsilon, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=get_tensor_model_parallel_group(check_initialized=False), - tp_size=self.config.tensor_model_parallel_size, - get_rng_state_tracker=( - get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None - ), - init_method=( - condition_init_method(config, init_method) - if not config.use_cpu_initialization - else lambda w: None - ), - bias=bias, - return_bias=self.te_return_bias, - parallel_mode="column", - return_layernorm_output=False, - zero_centered_gamma=self.config.layernorm_zero_centered_gamma, - **extra_kwargs, - ) - - world_size = get_tensor_model_parallel_world_size() - rank = get_tensor_model_parallel_rank() - - if config.use_cpu_initialization: - output_size_per_partition = divide(output_size, world_size) - _ = _initialize_affine_weight_cpu( - self.weight, - output_size, - input_size, - output_size_per_partition, - 0, - init_method=condition_init_method(config, init_method), - stride=1, - return_master_weight=False, - rank=rank, - world_size=world_size, - skip_set_tensor_parallel_attributes=True, - ) - if bias: - self.bias = Parameter( - torch.empty(output_size_per_partition, dtype=config.params_dtype) - ) - set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, 'allreduce', True) - - def forward(self, x): - """Forward.""" - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) - out = super().forward(x, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - - # TE only returns a tuple when return_bias is True, otherwise - # it returns a single Tensor, we always want to return two - # values regardless of the arguments. 
- if self.te_return_bias: - return out - return out, None - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """Sharding along axis 0, bias sharded""" - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets - ) - - -class TEColumnParallelLinear(TELinear): - """ - Wrapper for the Transformer-Engine's `Linear` layer but specialized similar - to megatron's `ColumnParallelLinear` layer. - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - gather_output: bool, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, - ): - if gather_output: - raise ValueError('Transformer Engine linear layers do not support gather_output = True') - - super().__init__( - input_size=input_size, - output_size=output_size, - parallel_mode="column", - config=config, - init_method=( - condition_init_method(config, init_method) - if not config.use_cpu_initialization - else lambda w: None - ), - bias=bias, - skip_bias_add=skip_bias_add, - is_expert=is_expert, - skip_weight_param_allocation=skip_weight_param_allocation, - tp_comm_buffer_name=tp_comm_buffer_name, - ) - - if config.use_cpu_initialization: - if is_expert: - world_size = get_expert_tensor_parallel_world_size() - rank = get_expert_tensor_parallel_rank() - else: - world_size = get_tensor_model_parallel_world_size() - rank = get_tensor_model_parallel_rank() - output_size_per_partition = divide(output_size, world_size) - _ = _initialize_affine_weight_cpu( - self.weight, - output_size, - input_size, - output_size_per_partition, - 0, - init_method=condition_init_method(config, init_method), - stride=1, - return_master_weight=False, - rank=rank, - world_size=world_size, - skip_set_tensor_parallel_attributes=True, - ) - if bias: - self.bias = Parameter( - torch.empty(output_size_per_partition, dtype=config.params_dtype) - ) - set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, 'allreduce', True) - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """Sharding along axis 0, bias sharded""" - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets - ) - - -class TERowParallelLinear(TELinear): - """ - Wrapper for the Transformer-Engine's `Linear` layer but specialized similar - to megatron's `RowParallelLinear` layer. 
- """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - input_is_parallel: bool, - skip_bias_add: bool, - is_expert: bool, - tp_comm_buffer_name: str = None, - ): - if not input_is_parallel: - raise ValueError( - "Transformer Engine linear layers do not support input_is_parallel = False" - ) - - super().__init__( - input_size=input_size, - output_size=output_size, - parallel_mode="row", - config=config, - init_method=( - condition_init_method(config, init_method) - if not config.use_cpu_initialization - else lambda w: None - ), - bias=bias, - skip_bias_add=skip_bias_add, - skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long - is_expert=is_expert, - tp_comm_buffer_name=tp_comm_buffer_name, - ) - if config.use_cpu_initialization: - if is_expert: - world_size = get_expert_tensor_parallel_world_size() - rank = get_expert_tensor_parallel_rank() - else: - world_size = get_tensor_model_parallel_world_size() - rank = get_tensor_model_parallel_rank() - input_size_per_partition = divide(input_size, world_size) - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - output_size, - input_size, - input_size_per_partition, - 1, - init_method=condition_init_method(config, init_method), - stride=1, - return_master_weight=False, - params_dtype=config.params_dtype, - rank=rank, - world_size=world_size, - skip_set_tensor_parallel_attributes=True, - ) - if bias: - self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, 'allreduce', True) - setattr(self.bias, 'sequence_parallel', config.sequence_parallel) - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """Sharding along axis 1, bias not sharded""" - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'weight': 1}, sharded_offsets - ) - - -class TEDotProductAttention(te.pytorch.DotProductAttention): - """ - Wrapper for the Transformer-Engine's `DotProductAttention` layer that also - has "flash attention" enabled. - - Note that if Megatron's parallel_state has not been initialized yet, the - tp_group and cp_group passed to TE will be None and must be set later - via set_tensor_parallel_group() and set_context_parallel_group(). - """ - - cp_stream: torch.cuda.Stream = None - - def __init__( - self, - config: TransformerConfig, - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - attention_dropout: float = None, - softmax_scale: float = None, - k_channels: int = None, - v_channels: int = None, - cp_comm_type: str = "p2p", - ): - self.config = config - self.te_forward_mask_type = False - self.qkv_format: str = 'sbhd' - - if self.config.apply_query_key_layer_scaling != bool( - int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) - ): - raise ValueError( - f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " - f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " - f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support " - f"setting query key layer scaling via argument, so these two must match." 
- ) - - extra_kwargs = {} - if is_te_min_version("0.11.0"): - extra_kwargs["num_gqa_groups"] = self.config.num_query_groups - elif self.config.num_query_groups != self.config.num_attention_heads: - raise ValueError( - f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, " - f"use a newer version of Transformer Engine. " - f"(num_query_groups ({self.config.num_query_groups}) != " - f"num_attention_heads ({self.config.num_attention_heads}))" - ) - - if is_te_min_version("0.10.0"): - extra_kwargs["attention_type"] = attention_type - # older version don't need attention_type - - if is_te_min_version("0.12.0", check_equality=False): - self.te_forward_mask_type = True - - # This check is important as CP config can be disabled while having a valid CP group - # Example - Disabling CP for encoder while a valid CP group exists for decoder - if self.config.context_parallel_size > 1: - assert is_te_min_version( - "1.0.0" - ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" - if getattr(TEDotProductAttention, "cp_stream") is None: - TEDotProductAttention.cp_stream = torch.cuda.Stream() - extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) - extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( - check_initialized=False - ) - extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream - if is_te_min_version("1.10.0"): - if cp_comm_type is None: - extra_kwargs["cp_comm_type"] = "p2p" - elif cp_comm_type == "a2a+p2p": - assert is_te_min_version("1.12.0"), ( - f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support" - "hierarchical cp commucation." - ) - extra_kwargs["cp_comm_type"] = "a2a+p2p" - extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups( - check_initialized=False - ) - else: - extra_kwargs["cp_comm_type"] = cp_comm_type - - if self.config.deterministic_mode: - if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: - raise RuntimeError( - "deterministic_mode is on and we are using DotProductAttention from " - "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " - f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." - ) - - if config.window_size is not None: - # Check version - assert is_te_min_version("1.2.0"), ( - f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" - "sliding window attention." - ) - extra_kwargs['window_size'] = config.window_size - - if is_te_min_version("1.10.0"): - # TE 1.10.0 introduces the ability to set the different k and v channels - kv_channels = ( - (k_channels, v_channels) - if k_channels is not None and v_channels is not None - else self.config.kv_channels - ) - extra_kwargs['softmax_scale'] = softmax_scale - else: - kv_channels = self.config.kv_channels - - self.kept_packed_seq_params = set( - field.name for field in dataclasses.fields(PackedSeqParams) - ) - if get_te_version() < PkgVersion("1.3.0"): - # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H - # copies (#555) - # These two arguments did not exist prior to 1.3.0 - self.kept_packed_seq_params.discard("max_seqlen_q") - self.kept_packed_seq_params.discard("max_seqlen_kv") - - if get_te_version() < PkgVersion("1.10.0"): - # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted - # in each individual sequence in THD format dataset - # These two arguments did not exist prior to 1.8.0. 
Full support added in 1.10.0 (#1012) - self.kept_packed_seq_params.discard("cu_seqlens_q_padded") - self.kept_packed_seq_params.discard("cu_seqlens_kv_padded") - - super().__init__( - num_attention_heads=self.config.num_attention_heads, - kv_channels=kv_channels, - attention_dropout=( - self.config.attention_dropout if attention_dropout is None else attention_dropout - ), - attn_mask_type=attn_mask_type.name, - sequence_parallel=self.config.sequence_parallel, - tp_size=self.config.tensor_model_parallel_size, - get_rng_state_tracker=( - get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None - ), - tp_group=get_tensor_model_parallel_group(check_initialized=False), - layer_number=layer_number, - **extra_kwargs, - ) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - attention_mask: Tensor, - attn_mask_type: AttnMaskType, - attention_bias: Tensor = None, - packed_seq_params: PackedSeqParams = None, - ): - """Forward.""" - packed_seq_kwargs = ( - {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params} - if packed_seq_params is not None - else {} - ) - # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set - # after init - if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): - self.qkv_format = 'bshd' - - qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) - - # WAR for peak memory usage. - # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388 - if self.config.apply_rope_fusion and qkv_format == 'bshd': - query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)] - # In PyTorch, the following two tensors are in fact the same: - # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) - # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) - # Stride for a dimension that is 1 has no meaning, so tensors created two different ways - # can have same shape but different strides. - # We unify them to the first one to pass the stride check in TE - if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): - value = value.as_strided(value.shape, key.stride()) - - attention_bias_kwargs = {} - if attention_bias is not None: - assert is_te_min_version("1.2.0"), ( - f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" - "`attention_bias`." - ) - attention_bias_kwargs = dict( - core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias - ) - - if self.te_forward_mask_type: - if qkv_format == 'thd' and is_te_min_version("1.7.0"): - # thd format uses flash attention with cuDNN kernel which requires is_padding=True, - # so the only acceptable mask types are `padding_causal` and `padding`. These do not - # necessarily indicate there are padded tokens in the sequence. 
-                if attn_mask_type == AttnMaskType.causal:
-                    attn_mask_type = AttnMaskType.padding_causal
-                elif attn_mask_type == AttnMaskType.no_mask:
-                    attn_mask_type = AttnMaskType.padding
-            core_attn_out = super().forward(
-                query,
-                key,
-                value,
-                attention_mask,
-                attn_mask_type=attn_mask_type.name,
-                **attention_bias_kwargs,
-                **packed_seq_kwargs,
-            )
-        else:
-            core_attn_out = super().forward(
-                query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
-            )
-
-        if self.config.apply_rope_fusion and qkv_format == 'bshd':
-            return core_attn_out.transpose(0, 1)
-        else:
-            return core_attn_out
-
-
-if is_te_min_version("1.9.0.dev0"):
-
-    class TEGroupedLinear(te.pytorch.GroupedLinear):
-        """
-        Wrapper for the Transformer-Engine's `GroupedLinear` layer.
-
-        Note that if Megatron's parallel_state has not been initialized
-        yet, the tp_group passed to TE will be None and must be set later
-        via set_tensor_parallel_group().
-        """
-
-        def __init__(
-            self,
-            num_gemms: int,
-            input_size: int,
-            output_size: int,
-            *,
-            parallel_mode: str,
-            config: ModelParallelConfig,
-            init_method: Callable,
-            bias: bool,
-            skip_bias_add: bool,
-            is_expert: bool = False,
-            tp_comm_buffer_name: str = None,
-        ):
-            self.config = config
-
-            # TE returns a zero length Tensor when bias=False and
-            # return_bias=True, but we prefer None. So in that case we
-            # tell TE to not return the bias, and return None
-            # ourselves. This way our forward always returns two values
-            # and we don't have to deal with the zero length Tensor.
-            self.te_return_bias = skip_bias_add and bias
-            self.is_first_microbatch = True
-            self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
-
-            extra_kwargs = _get_extra_te_kwargs(config)
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-            self.expert_parallel = self.config.expert_model_parallel_size > 1
-            if is_expert:
-                extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
-
-            # The comms between TP and EP group is explicitly handled by MoE token dispatcher.
-            # So we disable comms by making TE agnostic of model parallel.
- if is_expert: - tp_group = get_expert_tensor_parallel_group(check_initialized=False) - tp_size = get_expert_tensor_parallel_world_size() - else: - tp_group = get_tensor_model_parallel_group(check_initialized=False) - tp_size = get_tensor_model_parallel_world_size() - self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) - - if self.explicit_expert_comm: - if parallel_mode == "column": - output_size = divide(output_size, tp_size) - elif parallel_mode == "row": - input_size = divide(input_size, tp_size) - parallel_mode = None - tp_size = 1 - tp_group = None - - super().__init__( - num_gemms=num_gemms, - in_features=input_size, - out_features=output_size, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=tp_group, - tp_size=tp_size, - get_rng_state_tracker=( - get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None - ), - init_method=condition_init_method(config, init_method), - bias=bias, - return_bias=self.te_return_bias, - parallel_mode=parallel_mode, - **extra_kwargs, - ) - - for param in self.parameters(): - setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) - - def merge_extra_states( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - """ - Merge multiple "_extra_state" into one. - """ - self.init_fp8_metadata(num_gemms=self.num_gemms) - fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - - try: - state_list = [ - state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms) - ] - except KeyError: - # "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt. - return - - if not fp8_checkpoint: - return - state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list - state_list = [self._decode_extra_state(state) for state in state_list] - extra_fp8_variables = state_list[0]['extra_fp8_variables'] - extra_fp8_variables['num_gemms'] = self.num_gemms - extra_state = { - "scale_fwd": torch.cat( - [state['scale_fwd'].view(-1, 1) for state in state_list], dim=1 - ).view(-1), - "scale_inv_fwd": torch.cat( - [state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1 - ).view(-1), - "amax_history_fwd": torch.cat( - [state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1 - ).view(self.fp8_meta["recipe"].amax_history_len, -1), - "scale_bwd": torch.cat( - [state['scale_bwd'].view(-1, 1) for state in state_list], dim=1 - ).view(-1), - "scale_inv_bwd": torch.cat( - [state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1 - ).view(-1), - "amax_history_bwd": torch.cat( - [state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1 - ).view(self.fp8_meta["recipe"].amax_history_len, -1), - "extra_fp8_variables": extra_fp8_variables, - } - state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state) - - self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True) - - def forward(self, x, m_splits): - """Forward.""" - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) - out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - - # TE only returns a tuple when return_bias is True, otherwise - # it returns a single Tensor, we always want to return two - # values regardless of the arguments. 
- if self.te_return_bias: - return out - return out, None - - def _encode_extra_state(self, state): - state_serialized = io.BytesIO() - torch.save(state, state_serialized) - return state_serialized - - def _decode_extra_state(self, state): - if isinstance(state, torch.Tensor): - return pickle.loads(state.detach().cpu().numpy().tobytes()) - elif isinstance(state, io.BytesIO): - state.seek(0) - return torch.load(state, map_location="cuda") - else: - raise RuntimeError("Unsupported checkpoint format.") - - def _split_extra_state(self, state): - fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - - if not fp8_checkpoint: - return [state] * self.num_gemms - - state = self._decode_extra_state(state) - extra_states = [] - extra_fp8_variables = state['extra_fp8_variables'] - extra_fp8_variables['num_gemms'] = 1 - for gemm_idx in range(self.num_gemms): - tmp_state = { - "scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx], - "scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx], - "amax_history_fwd": state['amax_history_fwd'].view( - self.fp8_meta["recipe"].amax_history_len, 3, -1 - )[:, :, gemm_idx], - "scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx], - "scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx], - "amax_history_bwd": state['amax_history_bwd'].view( - self.fp8_meta["recipe"].amax_history_len, 2, -1 - )[:, :, gemm_idx], - "extra_fp8_variables": extra_fp8_variables, - } - extra_states.append(self._encode_extra_state(tmp_state)) - return extra_states - - def _sharded_state_dict_grouped( - self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None - ): - """ - prefix should be module_name to make keys identical to sequetial ones. - """ - sharded_state_dict = {} - full_state_dict = self.state_dict(prefix='', keep_vars=True) - num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms - local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms - ep_axis = len(sharded_offsets) - extra_states = self._split_extra_state(full_state_dict['_extra_state']) - for gemm_idx in range(self.num_gemms): - state_dict = { - f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], - f'{gemm_idx}._extra_state': extra_states[gemm_idx], - } - if self.use_bias: - state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] - sub_sd = make_sharded_tensors_for_checkpoint( - state_dict, - '', - tp_axis_map, - ( - *sharded_offsets, - (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), - ), - ) - # Remove expert layers indexing from sharded keys - replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) - sharded_state_dict.update( - { - f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], - f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ - f'{gemm_idx}._extra_state' - ], - } - ) - if self.use_bias: - sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] - # Adjust replica ids - replication along DP modulo EP - for k, sh_ten in sharded_state_dict.items(): - replica_id = sh_ten.replica_id - assert ( - len(replica_id) == 3 - ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' - sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank()) - return sharded_state_dict - - class TEColumnParallelGroupedLinear(TEGroupedLinear): - """ - Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized - to column-parallel style. 
- """ - - def __init__( - self, - num_gemms: int, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - tp_comm_buffer_name: str = None, - ): - - super().__init__( - num_gemms=num_gemms, - input_size=input_size, - output_size=output_size, - parallel_mode="column", - config=config, - init_method=condition_init_method(config, init_method), - bias=bias, - skip_bias_add=skip_bias_add, - is_expert=is_expert, - tp_comm_buffer_name=tp_comm_buffer_name, - ) - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """ - For each gemm, sharding along axis 0, bias sharded. - Assume sharded_offsets[-1] is the expert parallel offset. - """ - tp_axis_map = {} - for gemm_idx in range(self.num_gemms): - tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) - return super()._sharded_state_dict_grouped( - tp_axis_map, prefix, sharded_offsets, metadata - ) - - class TERowParallelGroupedLinear(TEGroupedLinear): - """ - Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized - to row-parallel style. - """ - - def __init__( - self, - num_gemms: int, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - tp_comm_buffer_name: str = None, - ): - - super().__init__( - num_gemms=num_gemms, - input_size=input_size, - output_size=output_size, - parallel_mode="row", - config=config, - init_method=condition_init_method(config, init_method), - bias=bias, - skip_bias_add=skip_bias_add, - is_expert=is_expert, - tp_comm_buffer_name=tp_comm_buffer_name, - ) - - def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """ - For each gemm, sharding along axis 1, bias not sharded. - Assume sharded_offsets[-1] is the expert parallel offset. - """ - tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} - return super()._sharded_state_dict_grouped( - tp_axis_map, prefix, sharded_offsets, metadata - ) - -else: - - TEGroupedLinear = None - TEColumnParallelGroupedLinear = None - TERowParallelGroupedLinear = None - - -class TEDelayedScaling(te.common.recipe.DelayedScaling): - """ - Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
- """ - - def __init__( - self, - config: ModelParallelConfig, - fp8_format: int, - override_linear_precision: tuple = (False, False, False), - ): - extra_kwargs = _get_extra_te_kwargs(config) - if is_te_min_version("1.6.0.dev0"): - extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention - extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention - if get_te_version() < PkgVersion("1.8.0"): - extra_kwargs["interval"] = config.fp8_interval - elif config.fp8_interval != 1: - warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.") - - super().__init__( - margin=config.fp8_margin, - fp8_format=fp8_format, - amax_compute_algo=config.fp8_amax_compute_algo, - amax_history_len=config.fp8_amax_history_len, - override_linear_precision=override_linear_precision, - **extra_kwargs, - ) - - -class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): - """Wraps TransformerEngine's CudaRNGStatesTracker so that it is - interchangeable with Megatron's RNG tracker""" - - def is_initialized(self): - """Checks if the internal RNG state has been set wirth set_states().""" - return self._is_initialized - - def reset(self): - """Reset the internal RNG state.""" - super().reset() - self._is_initialized = False - - def set_states(self, states): - """Set the internal RNG state.""" - super().set_states(states) - self._is_initialized = True - - def add(self, name, seed): - """Track the rng state.""" - super().add(name, seed) - self._is_initialized = True - - -def te_checkpoint( - forward_func, - distribute_saved_activations, - get_rng_state_tracker, - tp_group, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, -): - """Checkpointing with Transformer-Engine.""" - from transformer_engine.pytorch.distributed import checkpoint - - if is_te_min_version("1.5.0"): - return checkpoint( - forward_func, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - distribute_saved_activations=distribute_saved_activations, - get_rng_state_tracker=get_rng_state_tracker, - tp_group=tp_group, - ) - else: - return checkpoint( - forward_func, - distribute_saved_activations, - get_rng_state_tracker, - tp_group, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - - -try: - - from transformer_engine.pytorch.attention import _SplitAlongDim - - SplitAlongDim = _SplitAlongDim.apply - -except ImportError: - - SplitAlongDim = None - -try: - - from transformer_engine.pytorch.cpu_offload import ( - get_cpu_offload_context as _get_cpu_offload_context, - ) - - def get_cpu_offload_context( - enabled, num_layers, model_layers, activation_offloading, weight_offloading - ): - """Get CPU offload context and sync function.""" - if is_te_min_version("1.10.0.dev0"): - context, sync_func = _get_cpu_offload_context( - enabled, num_layers, model_layers, activation_offloading, weight_offloading - ) - else: - context, sync_func = _get_cpu_offload_context( - enabled, num_layers, activation_offloading, weight_offloading - ) - - return context, sync_func - -except ImportError: - - get_cpu_offload_context = None - -try: - - from transformer_engine.pytorch.attention import FusedRoPEFunc - - def fused_apply_rotary_pos_emb( - t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False - ) -> torch.Tensor: - """Apply rotary positional embedding to input tensor T in `sbhd` format.""" - if transpose_output_memory: - warnings.warn( - "transpose_output_memory is not supported by TE's fused RoPE and will be ignored." 
- ) - return FusedRoPEFunc.apply(t, freqs, "sbhd") - - def fused_apply_rotary_pos_emb_thd( - t: torch.Tensor, - cu_seqlens: torch.Tensor, - freqs: torch.Tensor, - cp_size: int = 1, - cp_rank: int = 0, - ) -> torch.Tensor: - """ - Apply rotary positional embedding to input tensor T in `thd` format with CP support. - """ - if is_te_min_version("1.11.0", check_equality=False): - return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank) - else: - return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens) - -except ImportError: - - pass - -try: - - from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import - -except ImportError: - - Fp8Padding = None - Fp8Unpadding = None +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import dataclasses +import io +import os +import pickle +import warnings +from typing import Any, Callable, Optional + +import torch +import transformer_engine as te +from packaging.version import Version as PkgVersion +from torch import Tensor +from torch.nn.parameter import Parameter + +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_global_ranks, + get_context_parallel_group, + get_expert_data_parallel_rank, + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, + get_hierarchical_context_parallel_groups, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name +from megatron.core.tensor_parallel.layers import ( + _initialize_affine_weight_cpu, + set_tensor_model_parallel_attributes, +) +from megatron.core.tensor_parallel.random import get_data_parallel_rng_tracker_name +from megatron.core.tensor_parallel.utils import divide +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +from megatron.core.utils import get_te_version, is_te_min_version + + +def _get_extra_te_kwargs(config: TransformerConfig): + extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} + + if is_te_min_version("0.12.0"): + if config.use_cpu_initialization: + extra_transformer_engine_kwargs["device"] = 'cpu' + elif config.init_model_with_meta_device: + extra_transformer_engine_kwargs["device"] = "meta" + else: + extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() + return extra_transformer_engine_kwargs + + +def condition_init_method(config, init_method): + """Condition TE init_method on config.perform_initialization.""" + return init_method if config.perform_initialization else (lambda w: None) + + +class TENorm: + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `LayerNorm` or `RMSNorm` based on input + """ + + # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? 
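# --- Editor's note: illustrative sketch, not part of this diff. ---------------
# The TENorm wrapper below picks a concrete norm class inside __new__ based on
# config.normalization and returns that instance directly. A minimal,
# self-contained version of the same factory pattern is sketched here;
# "DummyConfig", "Norm", and "MyRMSNorm" are made-up stand-ins, not Megatron or
# Transformer Engine classes.
import dataclasses

import torch


@dataclasses.dataclass
class DummyConfig:
    normalization: str = "LayerNorm"  # or "RMSNorm"


class MyRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm: scale by the reciprocal root-mean-square over the last dim.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class Norm:
    """Factory-style wrapper: __new__ returns an instance of another class."""

    def __new__(cls, config: DummyConfig, hidden_size: int, eps: float = 1e-5):
        if config.normalization == "LayerNorm":
            return torch.nn.LayerNorm(hidden_size, eps=eps)
        if config.normalization == "RMSNorm":
            return MyRMSNorm(hidden_size, eps=eps)
        raise ValueError(f"Unsupported normalization: {config.normalization}")


# Usage: Norm(DummyConfig("RMSNorm"), 16) yields a MyRMSNorm instance; because
# __new__ returns an object of a different class, Python skips Norm.__init__.
# ------------------------------------------------------------------------------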
+ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + if config.normalization == "LayerNorm": + instance = te.pytorch.LayerNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + elif config.normalization == "RMSNorm": + assert hasattr( + te.pytorch, "RMSNorm" + ), "Transformer-Engine >= v0.11 required to use this feature" + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + else: + raise Exception('Only LayerNorm and RMSNorm are curently supported') + + return instance + + +class TELinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + + parallel_mode currently supports 3 different values: + - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear) + - "row": Split the weight matrix along input dimension (used in TERowParallelLinear) + - "duplicated": No tensor parallelism and weight is duplicated across TP ranks + - Note: For expert linear layers, we will disable communication logic here + as TP communication is handled in token_dispatcher. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + parallel_mode: Optional[str], + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + skip_weight_param_allocation: bool, + tp_comm_buffer_name: Optional[str] = None, + is_expert: bool = False, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + extra_kwargs = _get_extra_te_kwargs(config) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + if is_te_min_version("1.5.0"): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + extra_kwargs["ub_overlap_rs"] = ( + self.config.tp_comm_overlap_rs + if hasattr(self.config, "tp_comm_overlap_rs") + else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs + ) + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs"] = False + else: + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs + extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs + # Disable ub overlap for experts. 
+ if is_expert: + extra_kwargs["ub_split_ag"] = False + extra_kwargs["ub_atomic_gemm_ag"] = False + extra_kwargs["ub_split_rs"] = False + extra_kwargs["ub_atomic_gemm_rs"] = False + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + rng_tracker_name = get_expert_parallel_rng_tracker_name() + else: + if parallel_mode == "duplicated": + rng_tracker_name = get_data_parallel_rng_tracker_name() + else: + rng_tracker_name = None + if is_te_min_version("1.7.0"): + extra_kwargs["rng_tracker_name"] = rng_tracker_name + + te_parallel_mode = parallel_mode + if parallel_mode == "duplicated": + # Handle non-parallel case + tp_group = None + tp_size = 1 + explicit_expert_comm = False + te_parallel_mode = None + else: + # Disable communications in TE when using TP or EP by + # making TE agnostic of model parallel. + if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + te_parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=te_parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, 'allreduce', not self.expert_parallel) + else: + # Reduce the gradient on DP group + setattr(param, 'allreduce', True) + if parallel_mode == "duplicated": + # Reduce the gradient further on the TP group since the weight is + # duplicated across TP ranks + setattr(param, 'sequence_parallel', self.config.sequence_parallel) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. 
+ if self.te_return_bias: + return out + return out, None + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Replicate cross TP/DP.""" + + # Provide the dist-ckpt support when TELinear is directly used + # It can only happen with duplicated parallel mode + assert ( + self.parallel_mode == None + ), "TELinear sharded_state_dict can only be used with duplicated parallel mode" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets) + + +class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): + """ + Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines + layernorm and linear layers + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: TransformerConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + self.config = config + + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + if is_expert: + raise ValueError('Transformer Engine linear layers do not yet support MoE') + + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + extra_kwargs = _get_extra_te_kwargs(config) + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + if is_te_min_version("0.11.0"): + extra_kwargs["normalization"] = self.config.normalization + elif self.config.normalization != "LayerNorm": + te_version = get_te_version() + raise ValueError( + f"Transformer Engine v{te_version} does not support {self.config.normalization}." 
+ ) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad + extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad + if is_te_min_version("1.5.0", check_equality=False): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + if is_te_min_version("1.6.0.dev0", check_equality=False): + extra_kwargs["ub_overlap_rs_dgrad"] = ( + self.config.tp_comm_overlap_rs_dgrad + if hasattr(self.config, "tp_comm_overlap_rs_dgrad") + else False + ) + if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + + if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + else: + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + super().__init__( + in_features=input_size, + out_features=output_size, + eps=self.config.layernorm_epsilon, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode="column", + return_layernorm_output=False, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **extra_kwargs, + ) + + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + if config.use_cpu_initialization: + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. 
+ if self.te_return_bias: + return out + return out, None + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TEColumnParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `ColumnParallelLinear` layer. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + skip_weight_param_allocation=skip_weight_param_allocation, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TERowParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `RowParallelLinear` layer. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + input_is_parallel: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + if not input_is_parallel: + raise ValueError( + "Transformer Engine linear layers do not support input_is_parallel = False" + ) + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + input_size_per_partition = divide(input_size, world_size) + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + input_size_per_partition, + 1, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + params_dtype=config.params_dtype, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + setattr(self.bias, 'sequence_parallel', config.sequence_parallel) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 1, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 1}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TEDotProductAttention(te.pytorch.DotProductAttention): + """ + Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + has "flash attention" enabled. + + Note that if Megatron's parallel_state has not been initialized yet, the + tp_group and cp_group passed to TE will be None and must be set later + via set_tensor_parallel_group() and set_context_parallel_group(). + """ + + cp_stream: torch.cuda.Stream = None + + def __init__( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: Optional[float] = None, + softmax_scale: Optional[float] = None, + k_channels: Optional[int] = None, + v_channels: Optional[int] = None, + cp_comm_type: str = "p2p", + ): + self.config = config + self.te_forward_mask_type = False + self.qkv_format: str = 'sbhd' + + if self.config.apply_query_key_layer_scaling != bool( + int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) + ): + raise ValueError( + f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " + f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " + f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. 
Transformer Engine does not support " + f"setting query key layer scaling via argument, so these two must match." + ) + + extra_kwargs: dict[str, Any] = {} + if is_te_min_version("0.11.0"): + extra_kwargs["num_gqa_groups"] = self.config.num_query_groups + elif self.config.num_query_groups != self.config.num_attention_heads: + raise ValueError( + f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, " + f"use a newer version of Transformer Engine. " + f"(num_query_groups ({self.config.num_query_groups}) != " + f"num_attention_heads ({self.config.num_attention_heads}))" + ) + + if is_te_min_version("0.10.0"): + extra_kwargs["attention_type"] = attention_type + # older version don't need attention_type + + if is_te_min_version("0.12.0", check_equality=False): + self.te_forward_mask_type = True + + # This check is important as CP config can be disabled while having a valid CP group + # Example - Disabling CP for encoder while a valid CP group exists for decoder + if self.config.context_parallel_size > 1: + assert is_te_min_version( + "1.0.0" + ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" + if getattr(TEDotProductAttention, "cp_stream") is None: + TEDotProductAttention.cp_stream = torch.cuda.Stream() + extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) + extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( + check_initialized=False + ) + extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream + if is_te_min_version("1.10.0"): + if cp_comm_type is None: + extra_kwargs["cp_comm_type"] = "p2p" + elif cp_comm_type == "a2a+p2p": + assert is_te_min_version("1.12.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support" + "hierarchical cp commucation." + ) + extra_kwargs["cp_comm_type"] = "a2a+p2p" + extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups( + check_initialized=False + ) + else: + extra_kwargs["cp_comm_type"] = cp_comm_type + + if self.config.deterministic_mode: + if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: + raise RuntimeError( + "deterministic_mode is on and we are using DotProductAttention from " + "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " + f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." + ) + + if config.window_size is not None: + # Check version + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "sliding window attention." 
+ ) + extra_kwargs['window_size'] = config.window_size + + if is_te_min_version("1.10.0"): + # TE 1.10.0 introduces the ability to set the different k and v channels + kv_channels = ( + (k_channels, v_channels) + if k_channels is not None and v_channels is not None + else self.config.kv_channels + ) + extra_kwargs['softmax_scale'] = softmax_scale + else: + kv_channels = self.config.kv_channels + + self.kept_packed_seq_params = set( + field.name for field in dataclasses.fields(PackedSeqParams) + ) + if get_te_version() < PkgVersion("1.3.0"): + # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H + # copies (#555) + # These two arguments did not exist prior to 1.3.0 + self.kept_packed_seq_params.discard("max_seqlen_q") + self.kept_packed_seq_params.discard("max_seqlen_kv") + + if get_te_version() < PkgVersion("1.10.0"): + # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted + # in each individual sequence in THD format dataset + # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012) + self.kept_packed_seq_params.discard("cu_seqlens_q_padded") + self.kept_packed_seq_params.discard("cu_seqlens_kv_padded") + + super().__init__( + num_attention_heads=self.config.num_attention_heads, + kv_channels=kv_channels, + attention_dropout=( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ), + attn_mask_type=attn_mask_type.name, + sequence_parallel=self.config.sequence_parallel, + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + tp_group=get_tensor_model_parallel_group(check_initialized=False), + layer_number=layer_number, + **extra_kwargs, + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Tensor, + attn_mask_type: AttnMaskType, + attention_bias: Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + """Forward.""" + packed_seq_kwargs = ( + {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params} + if packed_seq_params is not None + else {} + ) + # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set + # after init + if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): + self.qkv_format = 'bshd' + + qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) + + # WAR for peak memory usage. + # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388 + if self.config.apply_rope_fusion and qkv_format == 'bshd': + query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)] + # In PyTorch, the following two tensors are in fact the same: + # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) + # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) + # Stride for a dimension that is 1 has no meaning, so tensors created two different ways + # can have same shape but different strides. + # We unify them to the first one to pass the stride check in TE + if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): + value = value.as_strided(value.shape, key.stride()) + + attention_bias_kwargs = {} + if attention_bias is not None: + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "`attention_bias`." 
+ ) + attention_bias_kwargs = dict( + core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias + ) + + if self.te_forward_mask_type: + if qkv_format == 'thd' and is_te_min_version("1.7.0"): + # thd format uses flash attention with cuDNN kernel which requires is_padding=True, + # so the only acceptable mask types are `padding_causal` and `padding`. These do not + # necessarily indicate there are padded tokens in the sequence. + if attn_mask_type == AttnMaskType.causal: + attn_mask_type = AttnMaskType.padding_causal + elif attn_mask_type == AttnMaskType.no_mask: + attn_mask_type = AttnMaskType.padding + core_attn_out = super().forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type.name, + **attention_bias_kwargs, + **packed_seq_kwargs, + ) + else: + core_attn_out = super().forward( + query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs + ) + + if self.config.apply_rope_fusion and qkv_format == 'bshd': + return core_attn_out.transpose(0, 1) + else: + return core_attn_out + + +if is_te_min_version("1.9.0.dev0"): + + class TEGroupedLinear(te.pytorch.BatchLinear if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')) else te.pytorch.GroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + parallel_mode: Optional[str], + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + + extra_kwargs = _get_extra_te_kwargs(config) + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() + + # The comms between TP and EP group is explicitly handled by MoE token dispatcher. + # So we disable comms by making TE agnostic of model parallel. 
+ if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if self.explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + num_gemms=num_gemms, + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + + def merge_extra_states( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Merge multiple "_extra_state" into one. + """ + self.init_fp8_metadata(num_gemms=self.num_gemms) + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + try: + state_list = [ + state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms) + ] + except KeyError: + # "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt. + return + + if not fp8_checkpoint: + return + state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list + state_list = [self._decode_extra_state(state) for state in state_list] + extra_fp8_variables = state_list[0]['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = self.num_gemms + extra_state = { + "scale_fwd": torch.cat( + [state['scale_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_fwd": torch.cat( + [state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_fwd": torch.cat( + [state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "scale_bwd": torch.cat( + [state['scale_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_bwd": torch.cat( + [state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_bwd": torch.cat( + [state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "extra_fp8_variables": extra_fp8_variables, + } + state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state) + + self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True) + + def forward(self, x, m_splits): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. 
+ if self.te_return_bias: + return out + return out, None + + def _encode_extra_state(self, state): + state_serialized = io.BytesIO() + torch.save(state, state_serialized) + return state_serialized + + def _decode_extra_state(self, state): + if isinstance(state, torch.Tensor): + return pickle.loads(state.detach().cpu().numpy().tobytes()) + elif isinstance(state, io.BytesIO): + state.seek(0) + return torch.load(state, map_location="cuda") + else: + raise RuntimeError("Unsupported checkpoint format.") + + def _split_extra_state(self, state): + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + if not fp8_checkpoint: + return [state] * self.num_gemms + + state = self._decode_extra_state(state) + extra_states = [] + extra_fp8_variables = state['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = 1 + for gemm_idx in range(self.num_gemms): + tmp_state = { + "scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx], + "scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx], + "amax_history_fwd": state['amax_history_fwd'].view( + self.fp8_meta["recipe"].amax_history_len, 3, -1 + )[:, :, gemm_idx], + "scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx], + "scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx], + "amax_history_bwd": state['amax_history_bwd'].view( + self.fp8_meta["recipe"].amax_history_len, 2, -1 + )[:, :, gemm_idx], + "extra_fp8_variables": extra_fp8_variables, + } + extra_states.append(self._encode_extra_state(tmp_state)) + return extra_states + + def _sharded_state_dict_grouped( + self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None + ): + """ + prefix should be module_name to make keys identical to sequetial ones. + """ + sharded_state_dict = {} + full_state_dict = self.state_dict(prefix='', keep_vars=True) + num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms + local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms + ep_axis = len(sharded_offsets) + extra_states = self._split_extra_state(full_state_dict['_extra_state']) + for gemm_idx in range(self.num_gemms): + state_dict = { + f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], + f'{gemm_idx}._extra_state': extra_states[gemm_idx], + } + if self.use_bias: + state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] + sub_sd = make_sharded_tensors_for_checkpoint( + state_dict, + '', + tp_axis_map, + ( + *sharded_offsets, + (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), + ), + ) + # Remove expert layers indexing from sharded keys + replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) + sharded_state_dict.update( + { + f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], + f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ + f'{gemm_idx}._extra_state' + ], + } + ) + if self.use_bias: + sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] + # Adjust replica ids - replication along DP modulo EP + for k, sh_ten in sharded_state_dict.items(): + replica_id = sh_ten.replica_id + assert ( + len(replica_id) == 3 + ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' + if getattr(sh_ten, "is_data_parallel_fully_shard", False): + edp_replica_id = 0 + else: + edp_replica_id = get_expert_data_parallel_rank() + sh_ten.replica_id = (*replica_id[:2], edp_replica_id) + return sharded_state_dict + + class TEColumnParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's 
`GroupedLinear` layer but specialized + to column-parallel style. + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 0, bias sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {} + for gemm_idx in range(self.num_gemms): + tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + + class TERowParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized + to row-parallel style. + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 1, bias not sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + +else: + + TEGroupedLinear = None # type: ignore[assignment, misc] + TEColumnParallelGroupedLinear = None # type: ignore[assignment, misc] + TERowParallelGroupedLinear = None # type: ignore[assignment, misc] + + +class TEDelayedScaling(te.common.recipe.DelayedScaling): + """ + Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
+ """ + + def __init__( + self, + config: ModelParallelConfig, + fp8_format: int, + override_linear_precision: tuple = (False, False, False), + ): + extra_kwargs = _get_extra_te_kwargs(config) + if is_te_min_version("1.6.0.dev0"): + extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention + extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention + if get_te_version() < PkgVersion("1.8.0"): + extra_kwargs["interval"] = config.fp8_interval + elif config.fp8_interval != 1: + warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.") + + super().__init__( + margin=config.fp8_margin, + fp8_format=fp8_format, + amax_compute_algo=config.fp8_amax_compute_algo, + amax_history_len=config.fp8_amax_history_len, + override_linear_precision=override_linear_precision, + **extra_kwargs, + ) + + +class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): + """Wraps TransformerEngine's CudaRNGStatesTracker so that it is + interchangeable with Megatron's RNG tracker""" + + def __init__(self): + super().__init__() + self.reset() + + def is_initialized(self): + """Checks if the internal RNG state has been set wirth set_states().""" + return self._is_initialized + + def reset(self): + """Reset the internal RNG state.""" + super().reset() + self._is_initialized = False + + def set_states(self, states): + """Set the internal RNG state.""" + super().set_states(states) + self._is_initialized = True + + def add(self, name, seed): + """Track the rng state.""" + super().add(name, seed) + self._is_initialized = True + + +def te_checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, +): + """Checkpointing with Transformer-Engine.""" + from transformer_engine.pytorch.distributed import checkpoint + + if is_te_min_version("1.5.0"): + return checkpoint( + forward_func, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + distribute_saved_activations=distribute_saved_activations, + get_rng_state_tracker=get_rng_state_tracker, + tp_group=tp_group, + ) + else: + return checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + +try: + + from transformer_engine.pytorch.attention import _SplitAlongDim + + SplitAlongDim = _SplitAlongDim.apply + +except ImportError: + + SplitAlongDim = None + +try: + + from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context as _get_cpu_offload_context, + ) + + def get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ): + """Get CPU offload context and sync function.""" + if is_te_min_version("1.10.0.dev0"): + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ) + else: + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, activation_offloading, weight_offloading + ) + + return context, sync_func + +except ImportError: + + get_cpu_offload_context = None # type: ignore[assignment, misc] + +try: + + from transformer_engine.pytorch.attention import FusedRoPEFunc + + def fused_apply_rotary_pos_emb( + t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False + ) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `sbhd` format.""" + if transpose_output_memory: + 
warnings.warn( + "transpose_output_memory is not supported by TE's fused RoPE and will be ignored." + ) + return FusedRoPEFunc.apply(t, freqs, "sbhd") + + def fused_apply_rotary_pos_emb_thd( + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """ + Apply rotary positional embedding to input tensor T in `thd` format with CP support. + """ + if is_te_min_version("1.11.0", check_equality=False): + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank) + else: + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens) + +except ImportError: + + pass + +try: + + from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import + +except ImportError: + + Fp8Padding = None + Fp8Unpadding = None + +try: + + from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_sort_chunks_by_index, + moe_unpermute, + ) + + fused_permute = moe_permute + fused_unpermute = moe_unpermute + fused_sort_chunks_by_index = moe_sort_chunks_by_index + +except ImportError: + + fused_permute = None + fused_unpermute = None + fused_sort_chunks_by_index = None diff --git a/megatron/core/inference/ammo_support/__init__.py b/megatron/core/inference/ammo_support/__init__.py deleted file mode 100644 index 12be50cefe87ef9630f97d8b27c6590ae0fc4d06..0000000000000000000000000000000000000000 --- a/megatron/core/inference/ammo_support/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import warnings - -warnings.warn( - "The 'megatron.core.inference.ammo_support' module is deprecated and will be removed in a future release. " - "Please use megatron.core.inference.modelopt_support instead", - DeprecationWarning, -) diff --git a/megatron/core/inference/ammo_support/gpt/model_specs.py b/megatron/core/inference/ammo_support/gpt/model_specs.py deleted file mode 100644 index ba3bd9fa0fb6508300fca8f3fea52f105556bd16..0000000000000000000000000000000000000000 --- a/megatron/core/inference/ammo_support/gpt/model_specs.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec diff --git a/megatron/core/inference/ammo_support/gpt/state_dict_hooks.py b/megatron/core/inference/ammo_support/gpt/state_dict_hooks.py deleted file mode 100644 index 8532366222b60a4620ba5b0d7f59e30bb8b698c7..0000000000000000000000000000000000000000 --- a/megatron/core/inference/ammo_support/gpt/state_dict_hooks.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import ( - mcore_gpt_load_legacy_state_dict_pre_hook, - mcore_gpt_load_te_state_dict_pre_hook, -) diff --git a/megatron/core/inference/async_stream.py b/megatron/core/inference/async_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..c742dcb0cc7d54222f525e2df55d18c0209ea5b6 --- /dev/null +++ b/megatron/core/inference/async_stream.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright 2025 The vLLM authors. +# +# This code was adopted from https://github.com/vllm-project/vllm/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
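# --- Editor's note: illustrative sketch, not part of this diff. ---------------
# The AsyncStream class added below bridges a producer (which may run outside
# the event loop) and an async consumer: items are enqueued with
# loop.call_soon_threadsafe and drained through an async generator until a
# sentinel marks completion. "ToyStream" is a made-up, self-contained toy
# version of that pattern, not the Megatron or vLLM API.
import asyncio
from typing import AsyncGenerator

_DONE = object()  # sentinel marking the end of the toy stream


class ToyStream:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        self._loop = asyncio.get_running_loop()

    def put(self, item: str) -> None:
        # Thread-safe: schedule the enqueue as a callback on the event loop.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def finish(self) -> None:
        self._loop.call_soon_threadsafe(self._queue.put_nowait, _DONE)

    async def generator(self) -> AsyncGenerator[str, None]:
        # Yield items as they arrive; stop when the sentinel is seen.
        while True:
            item = await self._queue.get()
            if item is _DONE:
                return
            yield item


async def _demo() -> None:
    stream = ToyStream()
    for token in ("hello", "world"):
        stream.put(token)
    stream.finish()
    async for token in stream.generator():
        print(token)


# asyncio.run(_demo()) prints "hello" then "world".
# ------------------------------------------------------------------------------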
+ +import asyncio +from typing import Any, AsyncGenerator, Callable, Optional, Type, Union + +from megatron.core.inference.inference_request import InferenceRequest + +STOP_ITERATION = Exception() + + +class AsyncStream: + """ + Class for encapsulating an asynchronous stream of InferenceRequest outputs. + + Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long + """ + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self._request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + self._loop = asyncio.get_running_loop() + + def put(self, item: Union[InferenceRequest, Exception]) -> None: + """Adds a new value to the stream""" + if not self._finished: + self._loop.call_soon_threadsafe(self._queue.put_nowait, item) + + def finish(self, exception: Optional[Union[BaseException, Type[BaseException]]] = None) -> None: + """Completes the stream by adding a sentinel value""" + if not self._finished: + self._finished = True + self._loop.call_soon_threadsafe( + self._queue.put_nowait, + exception if self._is_raisable(exception) else STOP_ITERATION, + ) + + @property + def finished(self) -> bool: + """Whether the stream has finished""" + return self._finished + + async def generator(self) -> AsyncGenerator[InferenceRequest, None]: + """Creates an AsyncGenerator over the stream queue""" + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel() + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or ( + isinstance(value, type) and issubclass(value, BaseException) + ) diff --git a/megatron/core/inference/communication_utils.py b/megatron/core/inference/communication_utils.py index 0c23a583de8eb1242ecf9c63b69fa7d2fd6b8435..53d3eb483368e80443a1ce1935aefc949f2323b8 100644 --- a/megatron/core/inference/communication_utils.py +++ b/megatron/core/inference/communication_utils.py @@ -1,50 +1,54 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import torch - -from megatron.core import parallel_state - - -def _is_cuda(tensor): - """Check if a tensor is not none and is cuda.""" - assert tensor is not None - assert tensor.is_cuda - - -def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): - """Broadcast a tensor from last pipeline stage to all ranks.""" - - if parallel_state.is_pipeline_last_stage(): - _is_cuda(tensor) - assert tensor.is_contiguous() - else: - tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) - # Get the group and corresponding source rank. - src = parallel_state.get_pipeline_model_parallel_last_rank() - group = parallel_state.get_pipeline_model_parallel_group() - torch.distributed.broadcast(tensor, src, group) - return tensor - - -def recv_from_prev_pipeline_rank_(recv_buffer=None): - """Receive from previous pipeline stage and update the - input buffer inplace.""" - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_buffer, parallel_state.get_pipeline_model_parallel_prev_rank() - ) - reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). 
- torch.cuda.synchronize() - - -def send_to_next_pipeline_rank(tensor=None): - """Send output to the next pipeline stage.""" - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, parallel_state.get_pipeline_model_parallel_next_rank() - ) - reqs = torch.distributed.batch_isend_irecv([send_next_op]) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch + +from megatron.core import parallel_state + + +def _is_cuda(tensor): + """Check if a tensor is not none and is cuda.""" + assert tensor is not None + assert tensor.is_cuda + + +def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): + """Broadcast a tensor from last pipeline stage to all ranks.""" + + if parallel_state.is_pipeline_last_stage(): + assert size == list( + tensor.shape + ), f"Expected tensor of shape {size} but got {list(tensor.shape)}" + assert dtype == tensor.dtype, f"Expected tensor of type {dtype} but got {tensor.dtype}" + _is_cuda(tensor) + assert tensor.is_contiguous() + else: + tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) + # Get the group and corresponding source rank. + src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_pipeline_model_parallel_group() + torch.distributed.broadcast(tensor, src, group) + return tensor + + +def recv_from_prev_pipeline_rank_(recv_buffer=None): + """Receive from previous pipeline stage and update the + input buffer inplace.""" + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_buffer, parallel_state.get_pipeline_model_parallel_prev_rank() + ) + reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + +def send_to_next_pipeline_rank(tensor=None): + """Send output to the next pipeline stage.""" + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, parallel_state.get_pipeline_model_parallel_next_rank() + ) + reqs = torch.distributed.batch_isend_irecv([send_next_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() diff --git a/megatron/core/inference/engines/mcore_engine.py b/megatron/core/inference/engines/mcore_engine.py index 28ef46bf9292999ee68028d98c41ff8e8aa74f43..d080b3fee91271dd7f4ac8fad999d2e4c4823e55 100644 --- a/megatron/core/inference/engines/mcore_engine.py +++ b/megatron/core/inference/engines/mcore_engine.py @@ -1,120 +1,228 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from typing import Dict, List - -import torch - -from megatron.core.inference.engines.abstract_engine import AbstractEngine -from megatron.core.inference.inference_request import InferenceRequest -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.scheduler import Scheduler -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) - - -class MCoreEngine(AbstractEngine): - """The Megatron core backend constructor - - This is the backend that does a simple forward pass on the model. 
- Supports any model that is callable (Accepts the inputs and outputs the tensor) - - Args: - text_generation_controller (TextGenerationController): A text generation - controller that will be used to define how to preprocess prompts, generate - outputs and detokenizer the output tokens. - max_batch_size : The maxinum number of requests to process at once - random_seed (int, optional): Use a random seed if you want deterministic - results. Defaults to None. - """ - - def __init__( - self, - text_generation_controller: TextGenerationController, - max_batch_size, - random_seed: int = None, - ): - self.text_generation_controller = text_generation_controller - self.random_seed = random_seed - self.scheduler = Scheduler(max_batch_size=max_batch_size) - - def generate( - self, - prompts: List[str], - add_BOS: bool = False, - encoder_prompts: List[str] = None, - common_inference_params: SamplingParams = None, - sampling_params: SamplingParams = None, - ) -> dict: - """The megatron core inference backend generate function - - This backend returns the output generations as a dictionary. - It returns the prompt tokens along with the generated tokens, the prompt - plus the generated string and the output log probabilities if requested - - Args: - prompts (List[str]): All the prompts as a list of strings - add_BOS (bool): Whether to add BOS token to beginning of prompts - encoder_prompts (List[dict]): All the encoder prompts as a list of strings - common_inference_params: Deprecated. Only used for backward compatibility with - MCore <= 0.9.0. Use `sampling_params` going forward. - sampling_params (SamplingParams): The request-level sampling parameters - - Returns: - List[InferenceRequest]: The output is list of inference requests containing the - generated tokens, texts and log probs if required - """ - # TODO :M core- get rng state tracker - - if common_inference_params: - sampling_params = common_inference_params - - if self.random_seed: - torch.random.manual_seed(self.random_seed) - - for i in range(len(prompts)): - prompt = prompts[i] - encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None - prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS) - - self.scheduler.add_request( - prompt=prompt, - prompt_tokens=prompt_tokens, - encoder_prompt=encoder_prompt, - inference_parameters=sampling_params, - ) - - self.run_engine() - - result: List[InferenceRequest] = self.scheduler.completed_request_pool.values() - return result - - def run_engine(self): - """Main functionality to run inference - - Runs the engine until there are no requests in the queue. - - Args: - dynamic_generation (bool, optional): Set this to True, if you want - to enable dynamic batching. Mainly used with an inference server. - Defaults to False. - """ - while self.scheduler.have_requests_pending(): - active_requests: Dict[int, InferenceRequest] = self.scheduler.active_request_pool.copy() - result_dict: Dict[int, InferenceRequest] = ( - self.text_generation_controller.generate_all_output_tokens_static_batch( - active_requests - ) - ) - - self.scheduler.update_requests_pools(result_dict=result_dict) - - # TODO: Later for dynamic batching we will do something like this - """ - if dynamic_batching: - result_dict: Dict[ - int, InferenceRequest - ] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch( - active_requests - ) - self.scheduler.update_requests_pools(result_dict=result_dict) - """ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
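The reworked `broadcast_from_last_pipeline_stage` above now also validates the caller-supplied shape and dtype on the last pipeline stage. A hedged usage sketch with illustrative shapes, assuming `torch.distributed` and Megatron's parallel state are already initialized:

```python
import torch

from megatron.core import parallel_state
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage


def broadcast_sampled_tokens(batch_size: int) -> torch.Tensor:
    """Broadcast one sampled token id per sequence from the last PP stage to all ranks."""
    size = [batch_size]
    if parallel_state.is_pipeline_last_stage():
        # Only the last stage holds logits, so only it owns the real tensor; the
        # new asserts check that `size` and `dtype` match this tensor exactly.
        sampled_tokens = torch.randint(0, 32000, size, device=torch.cuda.current_device())
    else:
        sampled_tokens = None  # other pipeline ranks receive into a fresh buffer

    return broadcast_from_last_pipeline_stage(size, dtype=torch.int64, tensor=sampled_tokens)
```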
+import asyncio +import warnings +from collections import OrderedDict +from typing import AsyncGenerator, Dict, List, Optional, Union + +import torch + +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.engines.abstract_engine import AbstractEngine +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.scheduler import Scheduler +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class MCoreEngine(AbstractEngine): + """The Megatron core backend constructor + + This is the backend that does a simple forward pass on the model. + Supports any model that is callable (Accepts the inputs and outputs the tensor) + + Args: + text_generation_controller (TextGenerationController): A text generation + controller that will be used to define how to preprocess prompts, generate + outputs and detokenizer the output tokens. + max_batch_size (int, optional): The maximum number of requests to process at once. + Will be set from the InferenceWrapperConfig in `text_generation_controller` by + default. + random_seed (int, optional): Use a random seed if you want deterministic + results. Defaults to None. + """ + + def __init__( + self, + text_generation_controller: TextGenerationController, + max_batch_size: Optional[int] = None, + random_seed: Optional[int] = None, + ): + inference_wrapper_config = ( + text_generation_controller.inference_wrapped_model.inference_wrapper_config + ) + inference_max_batch_size = inference_wrapper_config.inference_max_requests + if max_batch_size is None: + max_batch_size = inference_max_batch_size + elif max_batch_size > inference_max_batch_size: + warnings.warn( + f"Engine `max_batch_size` ({max_batch_size}) > " + f"`inference_max_requests` in `inference_wrapper_config` " + f"({inference_max_batch_size}); setting `max_batch_size` to " + f"{inference_max_batch_size}", + UserWarning, + ) + max_batch_size = inference_max_batch_size + self.text_generation_controller = text_generation_controller + self.random_seed = random_seed + self.scheduler = Scheduler(max_batch_size=max_batch_size) + + def get_new_request_id(self) -> str: + """Gets a new request id from the scheduler""" + return self.scheduler.get_new_request_id() + + def add_request( + self, + prompt: Optional[str] = None, + add_BOS: bool = False, + encoder_prompt: Optional[str] = None, + inference_parameters: Optional[SamplingParams] = None, + streaming: bool = False, + inference_request: Optional[InferenceRequest] = None, + ) -> str: + """ + Adds a request to the scheduler and returns the request ID. + + Args: + prompt (str): A prompt string + add_BOS (bool): Whether to add BOS token to beginning of the prompt + encoder_prompt (str): The encoder prompt string + inference_parameters (SamplingParams): The inference parameters + streaming (bool): Whether to stream incremental outputs for this request + inference_request (InferenceRequest, optional): A fully constructed request. + Defaults to None. + + Returns: + The newly created request ID. 
+ """ + assert ( + prompt is not None or inference_request is not None + ), f"At least one of `prompt` or `inference_request` must be specified" + + if inference_request is None: + prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS) + else: + prompt_tokens = inference_request.prompt_tokens + + return self.scheduler.add_request( + prompt=prompt, + prompt_tokens=prompt_tokens, + encoder_prompt=encoder_prompt, + inference_parameters=inference_parameters, + streaming=streaming, + inference_request=inference_request, + ) + + def get_stream_generator( + self, request_id: str + ) -> Union[AsyncGenerator[InferenceRequest, None], None]: + """Returns the stream generator for the given request ID if it exists.""" + stream = self.scheduler.streams.get(request_id, None) + if stream is not None: + return stream.generator() + return None + + def generate( + self, + prompts: Optional[List[str]] = None, + add_BOS: bool = False, + encoder_prompts: Optional[List[str]] = None, + common_inference_params: Optional[SamplingParams] = None, + sampling_params: Optional[SamplingParams] = None, + inference_requests: Optional[List[InferenceRequest]] = None, + ) -> List[InferenceRequest]: + """The megatron core inference backend generate function + + This backend returns the output generations as a dictionary. + It returns the prompt tokens along with the generated tokens, the prompt + plus the generated string and the output log probabilities if requested + + Args: + prompts (List[str]): All the prompts as a list of strings + add_BOS (bool): Whether to add BOS token to beginning of prompts + encoder_prompts (List[dict]): All the encoder prompts as a list of strings + common_inference_params: Deprecated. Only used for backward compatibility with + MCore <= 0.9.0. Use `sampling_params` going forward. + sampling_params (SamplingParams): The request-level sampling parameters + inference_requests (List[InferenceRequest]): A pre-populated list of inference requests + + Returns: + List[InferenceRequest]: The output is list of inference requests containing the + generated tokens, texts and log probs if required + """ + # TODO :M core- get rng state tracker + + request_ids: List[str] = [] + + if self.random_seed: + torch.random.manual_seed(self.random_seed) + + if inference_requests is None: + assert prompts is not None + + if common_inference_params: + sampling_params = common_inference_params + + for i in range(len(prompts)): + prompt = prompts[i] + encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None + request_id = self.add_request( + prompt=prompt, + encoder_prompt=encoder_prompt, + inference_parameters=sampling_params, + ) + request_ids.append(request_id) + else: + for inference_request in inference_requests: + request_ids.append(inference_request.request_id) + self.scheduler.add_request(inference_request=inference_request) + + self.run_engine() + + result: List[InferenceRequest] = [ + self.scheduler.completed_request_pool[request_id] for request_id in request_ids + ] + return result + + def run_engine(self): + """Main functionality to run inference + + Runs the engine until there are no requests in the queue. + + Args: + dynamic_generation (bool, optional): Set this to True, if you want + to enable dynamic batching. Mainly used with an inference server. + Defaults to False. 
+ """ + while self.scheduler.have_requests_pending(): + active_requests: Dict[str, InferenceRequest] = self.scheduler.active_request_pool.copy() + active_streams: Dict[str, AsyncStream] = OrderedDict() + for request_id in active_requests: + if (stream := self.scheduler.streams.get(request_id, None)) is not None: + assert isinstance(stream, AsyncStream), stream + active_streams[request_id] = stream + result_dict: Dict[str, InferenceRequest] = ( + self.text_generation_controller.generate_all_output_tokens_static_batch( + active_requests, active_streams + ) + ) + + self.scheduler.update_requests_pools(result_dict=result_dict) + + # TODO: Later for dynamic batching we will do something like this + """ + if dynamic_batching: + result_dict: Dict[ + str, InferenceRequest + ] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch( + active_requests + ) + self.scheduler.update_requests_pools(result_dict=result_dict) + """ + + def _wrapped_run_engine(self, cuda_device): + """ + Explicitly sets the CUDA device before running the engine. + + This is to ensure that the CUDA device is correctly propagated when running + in a new thread context. + """ + torch.cuda.set_device(cuda_device) + self.run_engine() + + async def run_engine_async(self): + """Runs the engine asynchronously using asyncio""" + loop = asyncio.get_running_loop() + + await loop.run_in_executor(None, self._wrapped_run_engine, torch.cuda.current_device()) diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index ea0d67bfea26112db6219f53a3fcbec244e58ca3..398a99aeb46ae4ea4bcb837470b4a8d9733ccd38 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -1,39 +1,52 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from dataclasses import dataclass -from enum import Enum -from typing import List - -import torch - -from megatron.core.inference.sampling_params import SamplingParams - - -# class syntax -class Status(Enum): - """Enum for status""" - - WAITING_IN_QUEUE = 1 - ACTIVE_AND_GENERATING_TOKENS = 2 - ACTIVE_BUT_NOT_GENERATING_TOKENS = 3 - COMPLETED = 4 - - -@dataclass -class InferenceRequest: - """Class for one inference request - - Containing relevant data for an inference request - - """ - - request_id: str - prompt: str - inference_parameters: SamplingParams - prompt_tokens: List[int] - arrival_time: float - status: Status - encoder_prompt: str = None - generated_text: str = None - generated_tokens: torch.Tensor = None - generated_log_probs: torch.Tensor = None - generated_length: int = 0 +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
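The engine additions above (per-request `add_request`, `AsyncStream`-backed streaming, and `run_engine_async`) can be exercised roughly as below. This is a hedged sketch: `controller` is assumed to be an already-built `TextGenerationController` (inference-wrapped model plus tokenizer), `SamplingParams` is assumed to accept `num_tokens_to_generate` as its predecessor `CommonInferenceParams` did, and the scheduler is assumed to attach an `AsyncStream` when `streaming=True` is passed.

```python
import asyncio

from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.sampling_params import SamplingParams


def batch_demo(controller) -> MCoreEngine:
    """Blocking batch generation; `controller` is an already-built TextGenerationController."""
    engine = MCoreEngine(text_generation_controller=controller, random_seed=123)
    requests = engine.generate(
        prompts=["Megatron-Core is", "Pipeline parallelism lets"],
        sampling_params=SamplingParams(num_tokens_to_generate=32),
    )
    for request in requests:
        print(request.request_id, request.generated_text)
    return engine


async def streaming_demo(engine: MCoreEngine, prompt: str) -> None:
    """Streaming generation using the new AsyncStream plumbing."""
    request_id = engine.add_request(
        prompt=prompt, inference_parameters=SamplingParams(), streaming=True
    )
    stream = engine.get_stream_generator(request_id)
    engine_task = asyncio.create_task(engine.run_engine_async())
    async for partial in stream:  # yields incremental InferenceRequest snapshots
        print(partial.generated_text)
    await engine_task
```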
+from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import torch + +from megatron.core.inference.sampling_params import SamplingParams + + +# class syntax +class Status(Enum): + """Enum for status""" + + WAITING_IN_QUEUE = 1 + ACTIVE_AND_GENERATING_TOKENS = 2 + ACTIVE_BUT_NOT_GENERATING_TOKENS = 3 + COMPLETED = 4 + + +@dataclass(kw_only=True) +class InferenceRequest: + """Class for one inference request + + Containing relevant data for an inference request + + """ + + request_id: str + prompt: str + inference_parameters: Optional[SamplingParams] = None + prompt_tokens: Optional[List[int]] = None + arrival_time: Optional[float] = None + status: Optional[Status] = None + encoder_prompt: Optional[str] = None + generated_text: Optional[str] = None + segments: Optional[List[str]] = None + generated_segments: Optional[List[str]] = None + generated_sequence_lengths: Optional[List[int]] = None + generated_tokens: Optional[torch.Tensor] = None + generated_log_probs: Optional[torch.Tensor] = None + generated_length: Optional[int] = None + + +@dataclass(kw_only=True) +class VLMInferenceRequest(InferenceRequest): + """Class for a VLM inference request""" + + num_img_embeddings_per_tile: int + imgs: torch.Tensor + num_tiles: torch.Tensor + decoder_seq_length: int diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 647c4d191059edfb3310f50fcbcf8831473eda3c..fbaa94cc9d0ed458394c7bce327d3419ca584fa4 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -1,238 +1,315 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import abc -import math -from typing import Iterable, List, Union - -import torch - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.inference.communication_utils import ( - recv_from_prev_pipeline_rank_, - send_to_next_pipeline_rank, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.inference_params import InferenceParams -from megatron.core.models.gpt.gpt_model import GPTModel - - -# pylint: disable=line-too-long -class AbstractModelInferenceWrapper(abc.ABC): - """Abstract inference wrapper - - Extend this to create a version for your model. - """ - - def __init__( - self, - model: Union['LegacyGPTModel', GPTModel], - inference_wrapper_config: InferenceWrapperConfig, - ): - """Constructor for the model inference wrapper - - The wrapper prepares the model for inference, provides the required input data and runs the forward pass. - - Args: - model (Union[GPTModel, LegacyGPTModel]): The actual GPT model (MCore or MLM) - inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc. 
- """ - assert not isinstance( - model, Iterable - ), 'interleaving schedule is not supported for inference' - self.model = model - self.inference_wrapper_config = inference_wrapper_config - self.pipeline_communication_dtype = ( - torch.float - if self.inference_wrapper_config.fp32_residual_connection - else self.inference_wrapper_config.params_dtype - ) - - def prep_model_for_inference(self, prompts_tokens: torch.Tensor): - """A utility function for preparing model for inference - - The function gets called once before the auto regressive inference loop. It puts the model in eval mode , and gets some model and inference data parameters. Extend this to build position ids ,attention mask etc, so that required slices can be extracted during the forward pass. - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] - - """ - self.model.eval() - - # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True - self.model_is_pipeline_parallel = not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() - ) - self.prompts_tokens = prompts_tokens - batch_size, max_sequence_length = self.prompts_tokens.shape - self.inference_params = InferenceParams(batch_size, max_sequence_length) - - @abc.abstractmethod - def get_batch_for_context_window(self) -> List: - """Returns the input data for inference - - This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. - - """ - pass - - def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor: - """Utility to carry out simple forward pass for TP or no model parallel models - - Runs a very simple forward pass for model. Used in the case of models without any parallelism or only tensor parallelism. 
- - Args: - inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask] - - Returns: - torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] - """ - tokens, position_ids, attention_mask = inference_input - logits = self.model( - tokens, position_ids, attention_mask, inference_params=self.inference_params - ) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - self.inference_params.sequence_len_offset += tokens.size(1) - - return logits - - def _allocate_recv_buffer(self, batch_size, seq_len): - """Receive happens between the layers with size [seq_len, batch_size, hidden_size].""" - recv_size = (seq_len, batch_size, self.inference_wrapper_config.hidden_size) - return torch.empty( - recv_size, dtype=self.pipeline_communication_dtype, device=torch.cuda.current_device() - ) - - def forward_pass_with_pipeline_parallel_small_input_batch( - self, inference_input: List - ) -> torch.Tensor: - """Utility to carry out forward pass for PP models with very small inputs - - If a model is pipeline parallel, yet, the input global batch is very small, we compute a foward pass on the entire global batch, rather than splitting it up into micro batches and doing something more complex as in the forward_pass_with_pipeline_parallel_large_input_batch method - - Args: - inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask] - - Returns: - torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] - """ - tokens, position_ids, attention_mask = inference_input - batch_size, seq_len = tokens.shape - recv_buffer = None - if not parallel_state.is_pipeline_first_stage(): - recv_buffer = self._allocate_recv_buffer(batch_size, seq_len) - recv_from_prev_pipeline_rank_(recv_buffer) - - self.model.set_input_tensor(recv_buffer) - output_tensor = self.model( - tokens, position_ids, attention_mask, inference_params=self.inference_params - ) - - if not parallel_state.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor.type(dtype=self.pipeline_communication_dtype)) - - self.inference_params.sequence_len_offset += seq_len - - logits = None - if parallel_state.is_pipeline_last_stage(): - logits = output_tensor - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - return logits - - def forward_pass_with_pipeline_parallel_large_input_batch( - self, inference_input: List - ) -> torch.Tensor: - """Utility to carry out forward pass PP models. - - Runs the forward pass for models which are pipeline parallel. This is more complex than forward_pass_with_pipeline_parallel_small_input_batch coz this splits the global batch into small micro batches and runs them through the model. - - Args: - inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask] - - Returns: - torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] - """ - tokens, position_ids, attention_mask = inference_input - micro_batch_size = max( - 1, - self.inference_wrapper_config.inference_batch_times_seqlen_threshold // tokens.size(1), - ) - batch_size, seq_len = tokens.shape - # Round up to account for the last partial micro batch if present - num_micro_batches = math.ceil(batch_size / micro_batch_size) - - logits = None - # Preallocate memory for output logits. 
- if parallel_state.is_pipeline_last_stage(): - logits = torch.empty( - (batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size), - dtype=torch.float32, - device=torch.cuda.current_device(), - ) - - recv_buffer = None - if not parallel_state.is_pipeline_first_stage(): - recv_buffer = self._allocate_recv_buffer(micro_batch_size, seq_len) - for micro_batch_index in range(num_micro_batches): - start = micro_batch_index * micro_batch_size - end = min(start + micro_batch_size, batch_size) - tokens2use = tokens[start:end, ...] - position_ids2use = position_ids[start:end, ...] - current_micro_batch_size = end - start - - # Need to change recv buffer shape for the last partial microbatch (if exists) - if current_micro_batch_size != micro_batch_size: - recv_buffer = self._allocate_recv_buffer(current_micro_batch_size, seq_len) - - if not parallel_state.is_pipeline_first_stage(): - recv_from_prev_pipeline_rank_(recv_buffer) - - self.model.set_input_tensor(recv_buffer) - output_tensor = self.model( - tokens2use, position_ids2use, attention_mask, inference_params=self.inference_params - ) - - if not parallel_state.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - - self.inference_params.batch_size_offset += current_micro_batch_size - - if parallel_state.is_pipeline_last_stage(): - output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region( - output_tensor - ) - logits[start:end, ...] = output_tensor - - # Once done with all micro batches, we reset batch size offset and seq len offset - self.inference_params.sequence_len_offset += seq_len - self.inference_params.batch_size_offset = 0 - - # NOTE: Only returns the logits on the last pipeline stage - return logits - - def run_one_forward_step(self, inference_input: List) -> torch.Tensor: - """The forward pass of the model for inference - - Appropriate utility is called for the forward pass depending on the type of model parallelism used - - Args: - inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask] - - Returns: - torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models. - """ - if self.model_is_pipeline_parallel: - tokens = inference_input[0] - current_batch_size, seq_len = tokens.shape - # If input batch is large, we need to split into micro batches and run the forward pass - if ( - current_batch_size * seq_len - > self.inference_wrapper_config.inference_batch_times_seqlen_threshold - ): - return self.forward_pass_with_pipeline_parallel_large_input_batch(inference_input) - else: - # If input batch is very small we can do a simple forward pass on the entire global batch - return self.forward_pass_with_pipeline_parallel_small_input_batch(inference_input) - else: - return self.forward_pass_without_pipeline_parallel(inference_input) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
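Since `InferenceRequest` above is now declared with `@dataclass(kw_only=True)` and most fields default to `None`, requests are built with explicit keywords; positional construction raises `TypeError`. A small sketch with illustrative values:

```python
import time

from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams

request = InferenceRequest(
    request_id="0",
    prompt="Hello, Megatron",
    prompt_tokens=[15496, 11, 41029],     # illustrative token ids
    inference_parameters=SamplingParams(),
    arrival_time=time.time(),
    status=Status.WAITING_IN_QUEUE,
)

# With kw_only=True the following now raises TypeError instead of silently
# mapping positional arguments onto fields:
#   InferenceRequest("0", "Hello, Megatron")
```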
+import abc +import math +from typing import Any, Dict, Iterable, Optional, Union + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.inference.communication_utils import ( + recv_from_prev_pipeline_rank_, + send_to_next_pipeline_rank, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference_params import InferenceParams +from megatron.core.models.gpt.gpt_model import GPTModel + + +# pylint: disable=line-too-long +class AbstractModelInferenceWrapper(abc.ABC): + """Abstract inference wrapper + + Extend this to create a version for your model. + """ + + def __init__( + self, + model: Union['LegacyGPTModel', GPTModel], # type: ignore[name-defined] + inference_wrapper_config: InferenceWrapperConfig, + ): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input data and runs the forward pass. + + Args: + model (Union[GPTModel, LegacyGPTModel]): The actual GPT model (MCore or MLM) + inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc. + """ + assert not isinstance( + model, Iterable + ), 'interleaving schedule is not supported for inference' + self.model = model + self.inference_wrapper_config = inference_wrapper_config + self.pipeline_communication_dtype = ( + torch.float + if self.inference_wrapper_config.fp32_residual_connection + else self.inference_wrapper_config.params_dtype + ) + + max_batch_size = self.inference_wrapper_config.inference_max_requests + max_sequence_length = self.inference_wrapper_config.inference_max_seq_length + self.inference_params = InferenceParams(max_batch_size, max_sequence_length) + + def prep_model_for_inference(self, prompts_tokens: torch.Tensor): + """A utility function for preparing model for inference + + The function gets called once before the auto regressive inference loop. + It puts the model in eval mode. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + """ + self.model.eval() + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + self.inference_params.reset() + + @abc.abstractmethod + def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + A dict with all the inference input needed for the batch. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_batch_for_context_window(self, *args, **kwargs) -> Dict[str, Any]: + """Returns the input data for inference + + This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. + + """ + raise NotImplementedError() + + def _forward(self, inference_input): + """Runs a forward pass of the model. + + Args: + inference_input(Dict[str, Any]): The input data. + inference_params(InferenceParams): The inference parameters. + + Returns: + The model output logits. 
+ """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + return self.model( + tokens, position_ids, attention_mask, inference_params=self.inference_params + ) + + def _get_batch_size_and_seq_len( + self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None + ): + """ + Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len. + + Args: + tokens (torch.Tensor): The input tensor of shape (batch_size, seq_len). + recv_buffer_seq_len (int, optional): An optional recv buffer sequence length. + + Returns: + tuple: A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens + and seq_len is either the second dimension or recv_buffer_seq_len. + """ + batch_size = tokens.shape[0] + seq_len = recv_buffer_seq_len if recv_buffer_seq_len is not None else tokens.shape[1] + return batch_size, seq_len + + def _allocate_recv_buffer(self, batch_size, seq_len): + """Receive happens between the layers with size [seq_len, batch_size, hidden_size].""" + recv_size = (seq_len, batch_size, self.inference_wrapper_config.hidden_size) + return torch.empty( + recv_size, dtype=self.pipeline_communication_dtype, device=torch.cuda.current_device() + ) + + def forward_pass_without_pipeline_parallel( + self, inference_input: Dict[str, Any] + ) -> torch.Tensor: + """Utility to carry out simple forward pass for TP or no model parallel models + + Runs a very simple forward pass for model. Used in the case of models without any parallelism or only tensor parallelism. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + logits = self._forward(inference_input) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + self.inference_params.sequence_len_offset += tokens.size(1) + + return logits + + def forward_pass_with_pipeline_parallel_small_input_batch( + self, inference_input: Dict[str, Any], recv_buffer_seq_len: Optional[int] = None + ) -> torch.Tensor: + """Utility to carry out forward pass for PP models with very small inputs + + If a model is pipeline parallel, yet, the input global batch is very small, we compute a foward pass on the entire global batch, rather than splitting it up into micro batches and doing something more complex as in the forward_pass_with_pipeline_parallel_large_input_batch method + + Args: + inference_input (Dict[str, Any]): A dict containing the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. 
+ + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + batch_size, seq_len = self._get_batch_size_and_seq_len(tokens, recv_buffer_seq_len) + recv_buffer = None + if not parallel_state.is_pipeline_first_stage(): + recv_buffer = self._allocate_recv_buffer(batch_size, seq_len) + recv_from_prev_pipeline_rank_(recv_buffer) + + self.model.set_input_tensor(recv_buffer) + output_tensor = self._forward(inference_input) + + if not parallel_state.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor.type(dtype=self.pipeline_communication_dtype)) + + self.inference_params.sequence_len_offset += seq_len + + logits = None + if parallel_state.is_pipeline_last_stage(): + logits = output_tensor + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + # Explicitly cast logits to expected dtype + logits = logits.to(self.inference_wrapper_config.params_dtype) + + return logits + + def forward_pass_with_pipeline_parallel_large_input_batch( + self, inference_input: Dict[str, Any], recv_buffer_seq_len=None + ) -> torch.Tensor: + """Utility to carry out forward pass PP models. + + Runs the forward pass for models which are pipeline parallel. + This is more complex than forward_pass_with_pipeline_parallel_small_input_batch because + this splits the global batch into small micro batches and runs them through the model. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + micro_batch_size = max( + 1, + self.inference_wrapper_config.inference_batch_times_seqlen_threshold // tokens.size(1), + ) + batch_size, seq_len = self._get_batch_size_and_seq_len(tokens, recv_buffer_seq_len) + # Round up to account for the last partial micro batch if present + num_micro_batches = math.ceil(batch_size / micro_batch_size) + + logits = None + # Preallocate memory for output logits. + if parallel_state.is_pipeline_last_stage(): + logits = torch.empty( + (batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size), + dtype=self.pipeline_communication_dtype, + device=torch.cuda.current_device(), + ) + + recv_buffer = None + if not parallel_state.is_pipeline_first_stage(): + recv_buffer = self._allocate_recv_buffer(micro_batch_size, seq_len) + for micro_batch_index in range(num_micro_batches): + start = micro_batch_index * micro_batch_size + end = min(start + micro_batch_size, batch_size) + tokens2use = tokens[start:end, ...] + position_ids2use = position_ids[start:end, ...] 
+ current_micro_batch_size = end - start + + # Need to change recv buffer shape for the last partial microbatch (if exists) + if current_micro_batch_size != micro_batch_size: + recv_buffer = self._allocate_recv_buffer(current_micro_batch_size, seq_len) + + if not parallel_state.is_pipeline_first_stage(): + recv_from_prev_pipeline_rank_(recv_buffer) + + self.model.set_input_tensor(recv_buffer) + output_tensor = self._forward( + { + "tokens": tokens2use, + "position_ids": position_ids2use, + "attention_mask": attention_mask, + } + ) + + if not parallel_state.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + + self.inference_params.batch_size_offset += current_micro_batch_size + + if parallel_state.is_pipeline_last_stage(): + output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region( + output_tensor + ) + assert logits is not None + logits[start:end, ...] = output_tensor + + # Explicitly cast logits to expected dtype + logits = logits.to(self.inference_wrapper_config.params_dtype) + + # Once done with all micro batches, we reset batch size offset and seq len offset + self.inference_params.sequence_len_offset += seq_len + self.inference_params.batch_size_offset = 0 + + # NOTE: Only returns the logits on the last pipeline stage + return logits + + def run_one_forward_step( + self, inference_input: Dict[str, Any], recv_buffer_seq_len: Optional[int] = None + ) -> torch.Tensor: + """The forward pass of the model for inference + + Appropriate utility is called for the forward pass depending on the type of model parallelism used + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models. + """ + if self.model_is_pipeline_parallel: + tokens = inference_input["tokens"] + current_batch_size, seq_len = self._get_batch_size_and_seq_len( + tokens, recv_buffer_seq_len + ) + # If input batch is large, we need to split into micro batches and run the forward pass + if ( + current_batch_size * seq_len + > self.inference_wrapper_config.inference_batch_times_seqlen_threshold + ): + return self.forward_pass_with_pipeline_parallel_large_input_batch( + inference_input, recv_buffer_seq_len + ) + else: + # If input batch is very small we can do a simple forward pass on the entire global batch + return self.forward_pass_with_pipeline_parallel_small_input_batch( + inference_input, recv_buffer_seq_len + ) + else: + return self.forward_pass_without_pipeline_parallel(inference_input) diff --git a/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py index 166ed5e0672165590c9f6f8e09a3833ce4652bc5..5af4b09330a40d67f5ecbc24cfa4b45a858b477b 100644 --- a/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py @@ -1,90 +1,102 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
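To make the batching decision in `run_one_forward_step` above concrete, here is the same arithmetic with illustrative numbers: the micro-batched pipeline path is taken only when `batch_size * seq_len` exceeds `inference_batch_times_seqlen_threshold`, and the micro-batch size is derived from that threshold.

```python
import math

# Illustrative values, not defaults from the patch.
inference_batch_times_seqlen_threshold = 512 * 32    # 16384
batch_size, seq_len = 64, 512

pipeline_micro_batches = batch_size * seq_len > inference_batch_times_seqlen_threshold
micro_batch_size = max(1, inference_batch_times_seqlen_threshold // seq_len)   # 32
num_micro_batches = math.ceil(batch_size / micro_batch_size)                   # ceil(64 / 32) = 2

print(pipeline_micro_batches, micro_batch_size, num_micro_batches)             # True 32 2
```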
-from typing import List, Tuple - -import torch - -from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( - AbstractModelInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.models.gpt import GPTModel - - -# pylint: disable=line-too-long -class GPTInferenceWrapper(AbstractModelInferenceWrapper): - """Inference wrapper for GPT model""" - - def __init__(self, model: GPTModel, inference_wrapper_config: InferenceWrapperConfig): - """Constructor for the model inference wrapper - - The wrapper prepares the model for inference, provides the required input data, and runs the forward pass - - Args: - model (GPTModel): The GPT model (MCore or legacy) - inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc - """ - super().__init__(model, inference_wrapper_config) - - def prep_model_for_inference(self, prompts_tokens: torch.Tensor): - """A utility function for preparing model for inference - - This function is called before the forward pass. It puts the model in eval mode, builds position ids, and creates attention masks so that required slices can be extracted during the forward pass. - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] - """ - - super().prep_model_for_inference(prompts_tokens=prompts_tokens) - self.attention_mask, self.position_ids = self._build_attention_mask_and_position_ids( - prompts_tokens - ) - - def _build_attention_mask_and_position_ids( - self, prompts_tokens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Builds the full attention mask and position ids for the input tokens - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len] - """ - seq_length = prompts_tokens.size(1) - attention_mask = torch.tril( - torch.ones((1, seq_length, seq_length), device=prompts_tokens.device) - ).view(1, 1, seq_length, seq_length) - # Convert to boolean - attention_mask = attention_mask < 0.5 - - position_ids = ( - torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device) - .unsqueeze(0) - .expand_as(prompts_tokens) - ) - - return attention_mask, position_ids - - def get_batch_for_context_window( - self, context_start_position: int, context_end_position: int - ) -> List: - """Returns the inference data given context window - - This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data. - - Args: - context_start_position (int): Start of the context window. During the first inference step it is mostly 0 - context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length. - - Returns: - List: A list of inputs that will be used by your model in the forward step - """ - tokens2use = self.prompts_tokens[:, context_start_position:context_end_position] - positions2use = self.position_ids[:, context_start_position:context_end_position] - attention_mask2use = self.attention_mask[ - ..., context_start_position:context_end_position, :context_end_position - ] - data_at_step_idx = [tokens2use, positions2use, attention_mask2use] - return data_at_step_idx +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
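The GPT inference wrapper (removed version above, re-added below) builds a causal attention mask and broadcast position ids from the prompt tokens. A standalone, CPU-runnable sketch of that construction:

```python
import torch

prompts_tokens = torch.randint(0, 100, (2, 5))      # [batch_size, max_seq_len]
seq_length = prompts_tokens.size(1)

# Lower-triangular causal mask, then inverted to boolean "masked out" positions.
attention_mask = torch.tril(
    torch.ones((1, seq_length, seq_length))
).view(1, 1, seq_length, seq_length)
attention_mask = attention_mask < 0.5               # True where attention is NOT allowed

# Position ids 0..seq_length-1, broadcast to every sequence in the batch.
position_ids = (
    torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand_as(prompts_tokens)
)

print(attention_mask.shape, position_ids.shape)     # torch.Size([1, 1, 5, 5]) torch.Size([2, 5])
```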
+from typing import Any, Dict, Tuple + +import torch + +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.models.gpt import GPTModel + + +# pylint: disable=line-too-long +class GPTInferenceWrapper(AbstractModelInferenceWrapper): + """Inference wrapper for GPT model""" + + def __init__(self, model: GPTModel, inference_wrapper_config: InferenceWrapperConfig): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input data, and runs the forward pass + + Args: + model (GPTModel): The GPT model (MCore or legacy) + inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc + """ + super().__init__(model, inference_wrapper_config) + + def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + A dict with all the inference input needed for the batch. + """ + attention_mask, position_ids = self._build_attention_mask_and_position_ids(prompts_tokens) + return { + "tokens": prompts_tokens, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + def _build_attention_mask_and_position_ids( + self, prompts_tokens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Builds the full attention mask and position ids for the input tokens + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len] + """ + seq_length = prompts_tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, seq_length, seq_length), device=prompts_tokens.device) + ).view(1, 1, seq_length, seq_length) + # Convert to boolean + attention_mask = attention_mask < 0.5 + + position_ids = ( + torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device) + .unsqueeze(0) + .expand_as(prompts_tokens) + ) + + return attention_mask, position_ids + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length. 
+ + Returns: + Dict[str, Any]: A dict of inputs that will be used by your model in the forward step + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + tokens2use = tokens[:, context_start_position:context_end_position] + positions2use = position_ids[:, context_start_position:context_end_position] + attention_mask2use = attention_mask[ + ..., context_start_position:context_end_position, :context_end_position + ] + return { + "tokens": tokens2use, + "position_ids": positions2use, + "attention_mask": attention_mask2use, + } diff --git a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py index 14ca0f6fee0463ed34024b4e008f997ce8d16272..a746f8ce886815b940e1797fd81dafdc886531f9 100644 --- a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +++ b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py @@ -1,44 +1,50 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from dataclasses import dataclass - -import torch - - -@dataclass -class InferenceWrapperConfig: - """Config for the model inference wrapper - - NOTE : All the arguments here are obtained from arguments.py file - """ - - hidden_size: int - """Receive happens between the layers during PP with size [seq_len, batch_size, hidden_size]""" - - params_dtype: torch.dtype - """Can be torch.float or torch.half if --fp16 is used, or torch.bfloat16 if --bf16 is used""" - - inference_batch_times_seqlen_threshold: int - """if (batch-size * sequence-length) is smaller than this threshold then we will not pipeline - the batch.""" - - padded_vocab_size: int - """The final padded vocab size (Padded to make it divisible by - --make-vocab-size-divisible-by value)""" - - fp32_residual_connection: bool = False - """Move residual connections to fp32. Obtained from arguments.py""" - - def add_attributes(self, attribute_value_pair: dict): - """Utility to add more attributes to inference params - - Use this method to pass in a custom dictionary to add more configs to the instance created. - Use as follows: - c = InferenceWrapperConfig - c.add_attributes({'precision':'fp32'}) - - Args: - attribute_value_pair (dict): A dictionary containing attributes as the key names and - corresponding values. - """ - for key, value in attribute_value_pair.items(): - setattr(self, key, value) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + +import torch + + +@dataclass +class InferenceWrapperConfig: + """Config for the model inference wrapper + + NOTE : All the arguments here are obtained from arguments.py file + """ + + hidden_size: int + """Receive happens between the layers during PP with size [seq_len, batch_size, hidden_size]""" + + params_dtype: torch.dtype + """Can be torch.float or torch.half if --fp16 is used, or torch.bfloat16 if --bf16 is used""" + + inference_batch_times_seqlen_threshold: int + """if (batch-size * sequence-length) is smaller than this threshold then we will not pipeline + the batch.""" + + padded_vocab_size: int + """The final padded vocab size (Padded to make it divisible by + --make-vocab-size-divisible-by value)""" + + inference_max_requests: int = 8 + """ Maximum number of requests for inference (prefill & decode). Necessary for CUDA graphs. 
""" + + inference_max_seq_length: int = 2560 + """ Maximum sequence length for inference (prefill & decode). Necessary for CUDA graphs. """ + + fp32_residual_connection: bool = False + """Move residual connections to fp32. Obtained from arguments.py""" + + def add_attributes(self, attribute_value_pair: dict): + """Utility to add more attributes to inference params + + Use this method to pass in a custom dictionary to add more configs to the instance created. + Use as follows: + c = InferenceWrapperConfig + c.add_attributes({'precision':'fp32'}) + + Args: + attribute_value_pair (dict): A dictionary containing attributes as the key names and + corresponding values. + """ + for key, value in attribute_value_pair.items(): + setattr(self, key, value) diff --git a/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..96acca6c3677e72dce99c40a437b130ef83ca1bc --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import Any, Dict + +import torch + +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference_params import InferenceParams + + +# pylint: disable=line-too-long +class VLMInferenceWrapper(GPTInferenceWrapper): + """Inference wrapper for VLMs""" + + def prep_model_for_inference(self, prompts_tokens: torch.Tensor): + """A utility function for preparing model for inference + + The function gets called once before the auto regressive inference loop. + It puts the model in eval mode. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + """ + super().prep_model_for_inference(prompts_tokens) + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + self._recv_only_vision_embeds = False + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + # Checks if the previous stage only has a vision encoder, and that the current stage + # has part of the LM decoder. In this case, the current stage should only receive + # vision embeddings. + if pp_rank > 0: + self._recv_only_vision_embeds = ( + parallel_state.is_inside_encoder(pp_rank - 1) + and (not parallel_state.is_inside_decoder(pp_rank - 1)) + and parallel_state.is_inside_decoder() + ) + + # Checks if the current stage only has a vision encoder + self._encoder_only = ( + parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() + ) + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + def prep_inference_input( + self, + prompts_tokens: torch.Tensor, + num_img_embeddings_per_tile: int, + images: torch.Tensor, + num_tiles: torch.Tensor, + decoder_seq_length: int, + ): + """Prepares the inference input data. 
+ + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + num_img_embeddings_per_tile (int): The number of image embeddings per tile + images (torch.Tensor): The image embeddings + num_tiles (torch.Tensor): The number of tiles for each input image + decoder_seq_length (int): The decoder sequence length + """ + inference_input = super().prep_inference_input(prompts_tokens) + + total_num_tiles = torch.sum(num_tiles).item() + num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles + + batch_size, max_sequence_length = prompts_tokens.shape + self.inference_params = InferenceParams( + batch_size, max_sequence_length + num_img_embeddings + ) + + inference_input["images"] = images + inference_input["num_tiles"] = num_tiles + inference_input["num_img_embeddings"] = num_img_embeddings + inference_input["decoder_seq_length"] = decoder_seq_length + + return inference_input + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length. + + Returns: + Dict[str, Any]: A dict of inputs that will be used by your model in the forward step + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + images = inference_input["images"] + num_tiles = inference_input["num_tiles"] + num_img_embeddings = inference_input["num_img_embeddings"] + decoder_seq_length = inference_input["decoder_seq_length"] + + tokens2use = tokens[:, context_start_position:context_end_position] + positions2use = position_ids[:, context_start_position:context_end_position] + + return { + "tokens": tokens2use, + "position_ids": positions2use, + "images": images, + "num_tiles": num_tiles, + "num_img_embeddings": num_img_embeddings, + "decoder_seq_length": decoder_seq_length, + } + + def _forward(self, inference_input: Dict[str, Any]): + """Runs a forward pass of the model. + + Args: + inference_input(Dict[str, Any]): The input data. + + Returns: + The model output logits. 
+ """ + images = inference_input["images"] + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + num_image_tiles = inference_input["num_tiles"] + + output = self.model( + images, + tokens, + position_ids=position_ids, + attention_mask=None, + inference_params=self.inference_params, + num_image_tiles=num_image_tiles, + runtime_gather_output=True, + ) + if isinstance(output, tuple): + logits, _ = output + else: + logits = output + return logits + + def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor: + tokens = inference_input["tokens"] + num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() + num_img_embeddings = inference_input["num_img_embeddings"] + decoder_seq_length = inference_input["decoder_seq_length"] + num_tokens = tokens.size(1) + recv_buffer_seq_len = None + if num_image_tokens > 0: + # When there are image tokens and this stage only receives vision embeddings, + # adjust the recv buffer seq length to match the image embeddings sequence length. + # If there are image tokens and this stage receives full embeddings, make sure we + # compensate for expansion of image tokens. + # Note that this will set a recv_buffer_seq_len for the encoder stage, + # this length is irrelevant since that recv buffer is never allocated. + if self._recv_only_vision_embeds: + recv_buffer_seq_len = num_img_embeddings + else: + recv_buffer_seq_len = min( + num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length + ) + elif self._recv_only_vision_embeds: + # If this stage only receives vision embeddings and there are no image tokens + # we won't run the encoder and therefore shouldn't try to recv. + recv_buffer_seq_len = 0 + + # If the pipeline stage only has a vision encoder, then it only needs to + # run when there are image tokens + if not (self._encoder_only and num_image_tokens == 0): + output = super().run_one_forward_step( + inference_input, recv_buffer_seq_len=recv_buffer_seq_len + ) + else: + output = None + logits = output + + # On the first inference iteration, we compute image tokens. + # On every PP stage(although inference params should only matter for decoder), + # update the sequence length offset by the number of image tokens. + if num_tokens > 1 and num_image_tokens > 0: + if "image_tokens_count" not in self.inference_params.key_value_memory_dict: + self.inference_params.key_value_memory_dict["image_tokens_count"] = ( + num_img_embeddings + ) + + if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length: + self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens + else: + self.inference_params.sequence_len_offset += ( + self.inference_params.key_value_memory_dict["image_tokens_count"] + - num_image_tokens + ) + + return logits diff --git a/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py index 2e5f8466d7d650dd06b8be0f1cec1d581c898b50..f076528356e51e53b7147080984f4feb760a90a7 100644 --- a/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py @@ -1,215 +1,225 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
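The VLM wrapper added above follows the same prep/step protocol as the GPT wrapper it extends. As a rough orientation, the hedged sketch below shows how its methods are meant to compose during prefill; the model, config, input tensors, and the `num_img_embeddings_per_tile`/`decoder_seq_length` values are placeholders, and an initialized Megatron model-parallel environment is assumed.

```python
# Illustrative sketch only. `model` is assumed to be a multimodal GPT model and
# `config` an InferenceWrapperConfig; neither is constructed here, and the
# numeric values below are placeholders rather than recommended settings.
from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import (
    VLMInferenceWrapper,
)


def vlm_prefill(model, config, prompts_tokens, images, num_tiles):
    wrapper = VLMInferenceWrapper(model, config)

    # Puts the model in eval mode and records the pipeline/encoder topology.
    wrapper.prep_model_for_inference(prompts_tokens)

    # Builds the per-batch input; InferenceParams is sized to account for the
    # expansion of image tokens into image embeddings.
    inference_input = wrapper.prep_inference_input(
        prompts_tokens,
        num_img_embeddings_per_tile=576,  # placeholder, model dependent
        images=images,
        num_tiles=num_tiles,
        decoder_seq_length=4096,  # placeholder
    )

    # One prefill step over the full prompt window.
    prompt_len = prompts_tokens.shape[1]
    step_input = wrapper.get_batch_for_context_window(inference_input, 0, prompt_len)
    return wrapper.run_one_forward_step(step_input)
```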
-from collections import deque -from typing import Any, List, Tuple - -import numpy -import torch - -from megatron.core import tensor_parallel -from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset -from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( - AbstractModelInferenceWrapper, -) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) -from megatron.core.models.T5 import T5Model - - -# pylint: disable=line-too-long -class T5InferenceWrapper(AbstractModelInferenceWrapper): - """Constructor for the model inference wrapper - - The wrapper prepares the model for inference, provides the required input - data, and runs the forward pass - - Args: - model (T5Model): The T5 model (MCore or legacy) - inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed - use_local (bool): Whether the T5 model's transformer impl - is local (vs transformer_engine) - """ - - def __init__( - self, - model: T5Model, - inference_wrapper_config: InferenceWrapperConfig, - use_local: bool = False, - ): - super().__init__(model, inference_wrapper_config) - self.use_local = use_local - - def prep_model_for_inference( - self, prompts_tokens: torch.Tensor, encoder_prompts: List[str] = None, tokenizer: Any = None - ): - """A utility function for preparing model for inference - - This function is called before the forward pass. It puts the model in eval mode, builds - position ids, and creates attention masks so that required slices can be extracted during - the forward pass. - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] - encoder_prompts (dict): List of string of encoder input prompts - tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text - """ - - super().prep_model_for_inference(prompts_tokens=prompts_tokens) - - # get max_sequence_length - if hasattr(self.model, "module"): # if self.model is Float16Module - max_sequence_length = self.model.module.max_sequence_length - else: - max_sequence_length = self.model.max_sequence_length - - encoder_prompts_tokens_list = [ - self.tokenize_encoder_prompt(encoder_prompt, tokenizer) - for encoder_prompt in encoder_prompts - ] - self.batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens( - encoder_prompts_tokens_list, max_sequence_length, tokenizer - ) - - # create batch mask for encoder_prompt (self.batch_input_tokens) and - # decoder_input (self.prompts_tokens), similar to megatron/core/datasets/t5_dataset.py - decoder_prompts_tokens = self.prompts_tokens.cpu().numpy() - encoder_prompts_tokens = self.batch_encoder_prompts_tokens.cpu().numpy() - self.batch_mask_encoder = [] - self.batch_mask_decoder = [] - for i in range(len(self.prompts_tokens)): - mask_encoder = encoder_prompts_tokens[i] == tokenizer.pad - mask_decoder = decoder_prompts_tokens[i] == tokenizer.pad - self.batch_mask_encoder.append(mask_encoder) - self.batch_mask_decoder.append(mask_decoder) - self.batch_mask_encoder = torch.tensor(numpy.array(self.batch_mask_encoder)).cuda() - self.batch_mask_decoder = torch.tensor(numpy.array(self.batch_mask_decoder)).cuda() - - def tokenize_encoder_prompt( - self, encoder_prompt: str, tokenizer - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Utility to tokenize the encoder_prompt - - Args: - encoder_prompt (str): The encoder_prompt - tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string - - Returns: - torch.Tensor: Returns 
the tokenized prompt - """ - - # if there is the word "" in prompt, replacing it with special_additional_token, - # similar to processing step in megatron/core/datasets/t5_dataset.py - divided_encoder_prompt_list = encoder_prompt.split("") - masks_count = len(divided_encoder_prompt_list) - 1 - sentinels = deque(tokenizer.additional_special_tokens_ids) - - encoder_prompt_tokens = [] - for divided_encoder_prompt in divided_encoder_prompt_list: - divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt) - encoder_prompt_tokens.extend(divided_encoder_prompt_tokens) - if masks_count > 0: - sentinel = sentinels.popleft() - encoder_prompt_tokens.extend([sentinel]) - masks_count -= 1 - - return encoder_prompt_tokens - - def pad_encoder_prompts_tokens( - self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer - ) -> torch.Tensor: - """Method to pad input prompts - - Given a list of prompts, pad them all to uniform length - - Args: - encoder_prompts_tokens_list (List[List[int]]): A list containing the - encoder_input_tokens - max_sequence_length (int): Maximum of the length of the encoder inputs tokens - tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text - - Returns: - torch.Tensor: A torch tensor of shape [bs, max_sequence_length] - """ - - for encoder_prompt_tokens in encoder_prompts_tokens_list: - padding_size = max_sequence_length - len(encoder_prompt_tokens) - encoder_prompt_tokens.extend([tokenizer.pad] * padding_size) - - return torch.tensor(encoder_prompts_tokens_list).cuda() - - def get_batch_for_context_window( - self, context_start_position: int, context_end_position: int - ) -> List: - """Returns the inference data given context window - - This function gets called iteratively in a loop . Given the start and end context - positions , it extracts the appropriate data. - - Args: - context_start_position (int): Start of the context window. During - the first inference step it is mostly 0 - context_end_position (int): End of the context window. During the - last inference step it will mostly be the max generated sequence length. - - Returns: - List: A list of inputs that will be used by your model in the forward step - """ - - # T5 inference not yet support kv_cache - encoder_tokens2use = self.batch_encoder_prompts_tokens - decoder_tokens2use = self.prompts_tokens[:, :context_end_position] - encoder_mask2use = self.batch_mask_encoder - decoder_mask2use = self.batch_mask_decoder[:, :context_end_position] - - # Configure attention mask based on different conditions - # (e.g., transformer-impl, TE versions, TE backends) - [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = ( - T5MaskedWordPieceDataset.config_attention_mask( - encoder_tokens2use, - decoder_tokens2use, - encoder_mask2use, - decoder_mask2use, - self.use_local, - ) - ) - - data_at_step_idx = [ - encoder_tokens2use, - decoder_tokens2use, - encoder_mask2use, - decoder_mask2use, - encoder_decoder_mask2use, - ] - - return data_at_step_idx - - def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor: - """Utility to carry out simple forward pass for TP or no model parallel models - - Runs a very simple forward pass for model. Used in the case of models without - any parallelism or only tensor parallelism. 
- - Args: - inference_input (List): A list containg the inputs for the gpt - model [tokens, position ids, attention mask] - - Returns: - torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] - """ - [encoder_tokens, decoder_tokens, encoder_mask, decoder_mask, encoder_decoder_mask] = ( - inference_input - ) - tokens = decoder_tokens - - # T5 inference not yet support kv_cache - logits = self.model( - encoder_tokens, - decoder_tokens, - encoder_mask, - decoder_mask, - encoder_decoder_mask, - inference_params=None, - ) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - return logits +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from collections import deque +from typing import Any, Dict, List, Optional + +import numpy +import torch + +from megatron.core import tensor_parallel +from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.models.T5 import T5Model +from megatron.core.utils import get_attr_wrapped_model + + +# pylint: disable=line-too-long +class T5InferenceWrapper(AbstractModelInferenceWrapper): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input + data, and runs the forward pass + + Args: + model (T5Model): The T5 model (MCore or legacy) + inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed + use_local (bool): Whether the T5 model's transformer impl + is local (vs transformer_engine) + """ + + def __init__( + self, + model: T5Model, + inference_wrapper_config: InferenceWrapperConfig, + use_local: bool = False, + ): + super().__init__(model, inference_wrapper_config) + self.use_local = use_local + + def prep_inference_input( + self, + prompts_tokens: torch.Tensor, + encoder_prompts: Optional[List[str]] = None, + tokenizer: Any = None, + ) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + encoder_prompts (dict): List of string of encoder input prompts + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text + + Returns: + A dict with all the inference input needed for the batch. 
+ """ + # get max_sequence_length + max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length") + + encoder_prompts_tokens_list = [ + self.tokenize_encoder_prompt(encoder_prompt, tokenizer) + for encoder_prompt in encoder_prompts + ] + batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens( + encoder_prompts_tokens_list, max_sequence_length, tokenizer + ) + + # create batch mask for encoder_prompt (self.batch_input_tokens) and + # decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py + decoder_prompts_tokens = prompts_tokens + encoder_prompts_tokens = batch_encoder_prompts_tokens + decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy() + encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy() + batch_mask_encoder = [] + batch_mask_decoder = [] + for i in range(len(prompts_tokens)): + mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad + mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad + batch_mask_encoder.append(mask_encoder) + batch_mask_decoder.append(mask_decoder) + batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda() + batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda() + + return { + "encoder_tokens": encoder_prompts_tokens, + "decoder_tokens": decoder_prompts_tokens, + "encoder_mask": batch_mask_encoder, + "decoder_mask": batch_mask_decoder, + } + + def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor: + """Utility to tokenize the encoder_prompt + + Args: + encoder_prompt (str): The encoder_prompt + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string + + Returns: + torch.Tensor: Returns the tokenized prompt + """ + + # if there is the word "" in prompt, replacing it with special_additional_token, + # similar to processing step in megatron/core/datasets/t5_dataset.py + divided_encoder_prompt_list = encoder_prompt.split("") + masks_count = len(divided_encoder_prompt_list) - 1 + sentinels = deque(tokenizer.additional_special_tokens_ids) + + encoder_prompt_tokens = [] + for divided_encoder_prompt in divided_encoder_prompt_list: + divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt) + encoder_prompt_tokens.extend(divided_encoder_prompt_tokens) + if masks_count > 0: + sentinel = sentinels.popleft() + encoder_prompt_tokens.extend([sentinel]) + masks_count -= 1 + + return encoder_prompt_tokens + + def pad_encoder_prompts_tokens( + self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer + ) -> torch.Tensor: + """Method to pad input prompts + + Given a list of prompts, pad them all to uniform length + + Args: + encoder_prompts_tokens_list (List[List[int]]): A list containing the + encoder_input_tokens + max_sequence_length (int): Maximum of the length of the encoder inputs tokens + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text + + Returns: + torch.Tensor: A torch tensor of shape [bs, max_sequence_length] + """ + + for encoder_prompt_tokens in encoder_prompts_tokens_list: + padding_size = max_sequence_length - len(encoder_prompt_tokens) + encoder_prompt_tokens.extend([tokenizer.pad] * padding_size) + + return torch.tensor(encoder_prompts_tokens_list).cuda() + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a 
loop . Given the start and end context + positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During + the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the + last inference step it will mostly be the max generated sequence length. + + Returns: + Dict: A dict of inputs that will be used by your model in the forward step + """ + + # T5 inference not yet support kv_cache + encoder_tokens2use = inference_input["encoder_tokens"] + decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position] + encoder_mask2use = inference_input["encoder_mask"] + decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position] + + # Configure attention mask based on different conditions + # (e.g., transformer-impl, TE versions, TE backends) + [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = ( + T5MaskedWordPieceDataset.config_attention_mask( + encoder_tokens2use, + decoder_tokens2use, + encoder_mask2use, + decoder_mask2use, + self.use_local, + ) + ) + + return { + "encoder_tokens": encoder_tokens2use, + "decoder_tokens": decoder_tokens2use, + "encoder_mask": encoder_mask2use, + "decoder_mask": decoder_mask2use, + "encoder_decoder_mask": encoder_decoder_mask2use, + } + + def forward_pass_without_pipeline_parallel( + self, inference_input: Dict[str, Any] + ) -> torch.Tensor: + """Utility to carry out simple forward pass for TP or no model parallel models + + Runs a very simple forward pass for model. Used in the case of models without + any parallelism or only tensor parallelism. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt + model [tokens, position ids, attention mask] + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + encoder_tokens = inference_input["encoder_tokens"] + decoder_tokens = inference_input["decoder_tokens"] + encoder_mask = inference_input["encoder_mask"] + decoder_mask = inference_input["decoder_mask"] + encoder_decoder_mask = inference_input["encoder_decoder_mask"] + tokens = decoder_tokens + + # T5 inference not yet support kv_cache + logits = self.model( + encoder_tokens, + decoder_tokens, + encoder_mask, + decoder_mask, + encoder_decoder_mask, + inference_params=None, + ) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + return logits diff --git a/megatron/core/inference/modelopt_support/__init__.py b/megatron/core/inference/modelopt_support/__init__.py index f8eb8f3d9f9fea05e9ae61884ec3cb6787a7c07b..4da05305a8813ebf59f6833e634d82019efef0f2 100644 --- a/megatron/core/inference/modelopt_support/__init__.py +++ b/megatron/core/inference/modelopt_support/__init__.py @@ -1,8 +1,10 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt). - -ModelOpt is a library comprising state-of-the-art model optimization techniques including quantization and sparsity to -compress model for efficient inference on NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless -experience for users to optimize their Megatron-core models for inference. More details on ModelOpt including -installation and usage can be found at https://github.com/NVIDIA/TensorRT-Model-Optimizer. -""" +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
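For the reworked T5 wrapper, the batch-level inputs are now returned from `prep_inference_input` instead of being cached on the instance. Below is a hedged sketch of a single decode step, assuming an existing T5 model, wrapper config, and tokenizer (none constructed here).

```python
# Illustrative sketch only; T5 inference does not yet use a KV cache, so each
# step re-runs the decoder over the context window [0, context_end_position).
from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
    T5InferenceWrapper,
)


def t5_first_decode_step(t5_model, wrapper_config, tokenizer, prompts_tokens, encoder_prompts):
    wrapper = T5InferenceWrapper(t5_model, wrapper_config, use_local=False)

    # Tokenizes and pads the encoder prompts and builds the encoder/decoder masks.
    inference_input = wrapper.prep_inference_input(
        prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=tokenizer
    )

    # Slice out the batch for the first decode position and run the forward pass
    # (this sketch assumes no pipeline parallelism).
    step_input = wrapper.get_batch_for_context_window(
        inference_input, context_start_position=0, context_end_position=1
    )
    return wrapper.forward_pass_without_pipeline_parallel(step_input)
```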
+"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt). + +ModelOpt is a library comprising state-of-the-art model optimization techniques +including quantization and sparsity to compress model for efficient inference on +NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless +experience for users to optimize their Megatron-core models for inference. +More details on ModelOpt including installation and usage can be found at +https://github.com/NVIDIA/TensorRT-Model-Optimizer. +""" diff --git a/megatron/core/inference/modelopt_support/gpt/model_specs.py b/megatron/core/inference/modelopt_support/gpt/model_specs.py index 4d422bc2f372a0a639115af4db5aba39a0996282..b11232ab869797bfcedfa401130d237d43c35ee9 100644 --- a/megatron/core/inference/modelopt_support/gpt/model_specs.py +++ b/megatron/core/inference/modelopt_support/gpt/model_specs.py @@ -1,63 +1,68 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - - -# Use this spec for ModelOpt PTQ and TensorRT-LLM export -def get_gpt_layer_modelopt_spec( - num_experts: int = None, - moe_grouped_gemm: bool = False, - remap_te_layernorm: bool = False, - qk_layernorm: bool = False, -) -> ModuleSpec: - """Mix the native spec with TENorm. - - This is essentially the native local spec except for the layernorm implementation - is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex - has stopped supporting RMSNorm needed by llama. - """ - mlp = _get_mlp_module_spec( - use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False - ) - sharded_state_dict_keys_map = {} - if remap_te_layernorm: - if num_experts: - sharded_state_dict_keys_map = { - 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_' - } - else: - sharded_state_dict_keys_map = { - 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', - 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', - } - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=TENorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=RowParallelLinear, - q_layernorm=TENorm if qk_layernorm else IdentityOp, - k_layernorm=TENorm if qk_layernorm else IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=TENorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - # Map TE-layernorm-fusion keys back - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - ), - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ +from typing import Optional + +from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + + +# Use this spec for ModelOpt PTQ and TensorRT-LLM export +def get_gpt_layer_modelopt_spec( + num_experts: Optional[int] = None, + local_core_attention: bool = False, + moe_grouped_gemm: bool = False, + remap_te_layernorm: bool = False, + qk_layernorm: bool = False, +) -> ModuleSpec: + """Mix the native spec with TENorm. + + This is essentially the native local spec except for the layernorm implementation + is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex + has stopped supporting RMSNorm needed by llama. + """ + core_attention = DotProductAttention if local_core_attention else TEDotProductAttention + mlp = get_mlp_module_spec( + use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False + ) + sharded_state_dict_keys_map = {} + if remap_te_layernorm: + if num_experts: + sharded_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_' + } + else: + sharded_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + } + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=core_attention, + linear_proj=RowParallelLinear, + q_layernorm=TENorm if qk_layernorm else IdentityOp, + k_layernorm=TENorm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + # Map TE-layernorm-fusion keys back + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) diff --git a/megatron/core/inference/modelopt_support/mamba/__init__.py b/megatron/core/inference/modelopt_support/mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f3599e0d5c5bc87aa12cd22e19e6881207b151 --- /dev/null +++ b/megatron/core/inference/modelopt_support/mamba/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/modelopt_support/mamba/model_specs.py b/megatron/core/inference/modelopt_support/mamba/model_specs.py new file mode 100644 index 0000000000000000000000000000000000000000..b2708689aa4d513dc0db2ae72dcded66a941b560 --- /dev/null +++ b/megatron/core/inference/modelopt_support/mamba/model_specs.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
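The updated `get_gpt_layer_modelopt_spec` adds a `local_core_attention` switch alongside the existing options. A minimal sketch of building the spec is below; it assumes Transformer Engine is installed (the module imports TENorm/TEDotProductAttention at import time), and wiring the spec into a GPTModel is omitted since that requires an initialized model-parallel state.

```python
# Minimal sketch of constructing the ModelOpt-friendly GPT layer spec.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec

layer_spec = get_gpt_layer_modelopt_spec(
    num_experts=None,            # dense model; pass an int to get the MoE MLP spec
    local_core_attention=False,  # True swaps TEDotProductAttention for DotProductAttention
    moe_grouped_gemm=False,
    remap_te_layernorm=True,     # remap TE layernorm-fusion keys in sharded state dicts
    qk_layernorm=False,
)
# The resulting ModuleSpec can then be passed as transformer_layer_spec to a GPT model.
print(layer_spec.submodules.self_attention.submodules.core_attention)
```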
+ +from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + + +# Use this spec for ModelOpt PTQ and TensorRT-LLM export +def get_mamba_stack_modelopt_spec( + local_core_attention: bool = False, remap_te_layernorm: bool = False +) -> ModuleSpec: + """Mix the native spec with TENorm. + + This is essentially the native local spec except for the layernorm implementation + is using TENorm from Transformer-Engine. + """ + mamba_state_dict_keys_map = {} + transformer_state_dict_keys_map = {} + if remap_te_layernorm: + mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'} + transformer_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + } + + mamba_layer = ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + norm=TENorm, + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=ColumnParallelLinear, out_proj=RowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=mamba_state_dict_keys_map, + ), + ) + + core_attention = DotProductAttention if local_core_attention else TEDotProductAttention + attention_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=core_attention, + linear_proj=RowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=transformer_state_dict_keys_map, + ), + ) + + mlp_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=transformer_state_dict_keys_map, + ), + ) + + return ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer + ), + ) diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py index 8ffcb6321dc13cb04574accaa0636987454222fa..d73a6124c487da068983c7840a2d5d9fe36ad745 100644 --- a/megatron/core/inference/sampling_params.py +++ b/megatron/core/inference/sampling_params.py @@ -1,35 +1,36 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
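The new Mamba counterpart exposes the same two switches. A small sketch, again assuming Transformer Engine and the Mamba dependencies are importable:

```python
# Minimal sketch of constructing the ModelOpt-friendly Mamba stack spec.
from megatron.core.inference.modelopt_support.mamba.model_specs import (
    get_mamba_stack_modelopt_spec,
)

stack_spec = get_mamba_stack_modelopt_spec(
    local_core_attention=False,  # keep TEDotProductAttention in attention layers
    remap_te_layernorm=True,     # remap TE layernorm-fusion checkpoint keys
)
# MambaStackSubmodules carries three layer specs: mamba, attention, and MLP.
print(stack_spec.submodules.mamba_layer.module, stack_spec.submodules.attention_layer.module)
```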
-from dataclasses import dataclass - - -@dataclass -class SamplingParams: - """Inference parameters sent along with the prompts. - This class contains request-level attributes that control the sampling techniques used when - generating text. This is distinct from megatron.core.InferenceParams, which is sets model-level - inference attributes such as the maximum sequence length, and contains the KV cache. - - For an explanation of these parameters refer to this blog - https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and- - temperature-parameters-ed6a31313910 - """ - - temperature: float = 1.0 - top_k: int = 0 - top_p: float = 0.0 - return_log_probs: bool = False - num_tokens_to_generate: int = 30 - - def add_attributes(self, attribute_value_pair: dict): - """Utility to add more attributes to sampling params - - Use this method to pass in a custom dictionary to add more sampling parameter attributes. - c = SamplingParams - c.add_attributes({'min_length':4, 'eod_id':153}) - - Args: - attribute_value_pair (dict): A dictionary containing attributes as the key names and - their values as the values. - """ - for key, value in attribute_value_pair.items(): - setattr(self, key, value) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + """Inference parameters sent along with the prompts. + This class contains request-level attributes that control the sampling techniques used when + generating text. This is distinct from megatron.core.InferenceParams, which is sets model-level + inference attributes such as the maximum sequence length, and contains the KV cache. + + For an explanation of these parameters refer to this blog + https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and- + temperature-parameters-ed6a31313910 + """ + + temperature: float = 1.0 + top_k: int = 0 + top_p: float = 0.0 + return_log_probs: bool = False + return_segments: bool = False # Whether to return individually detokenized tokens + num_tokens_to_generate: int = 30 + + def add_attributes(self, attribute_value_pair: dict): + """Utility to add more attributes to sampling params + + Use this method to pass in a custom dictionary to add more sampling parameter attributes. + c = SamplingParams + c.add_attributes({'min_length':4, 'eod_id':153}) + + Args: + attribute_value_pair (dict): A dictionary containing attributes as the key names and + their values as the values. + """ + for key, value in attribute_value_pair.items(): + setattr(self, key, value) diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py index ef177232b42419fdae7284155701e403368561a6..d3afcb06ede12e182aa1156b9454ca3f19af8f8b 100644 --- a/megatron/core/inference/scheduler.py +++ b/megatron/core/inference/scheduler.py @@ -1,127 +1,175 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import time -import typing -from collections import OrderedDict -from typing import Dict - -import torch - -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import Counter - - -class Scheduler: - """Scheduler for handling requests to inference engine - - This class is responsible for handing of all the incomign requests - - Args: - max_batch_size (int): The max batch size that we can pass to the - inference engine at a time. 
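Since `SamplingParams` gains the `return_segments` flag in this change, a short usage example may help; the `min_length`/`eod_id` extras mirror the docstring and are illustrative only.

```python
from megatron.core.inference.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.8,
    top_p=0.9,                 # top_k stays 0, so nucleus sampling is used
    num_tokens_to_generate=64,
    return_log_probs=True,
    return_segments=False,     # set True to also return individually detokenized tokens
)
# Attach request-specific extras without subclassing the dataclass.
params.add_attributes({"min_length": 4, "eod_id": 153})
assert params.min_length == 4
```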
- """ - - def __init__(self, max_batch_size: int): - self.max_batch_size = max_batch_size - self.active_request_pool: Dict[int, InferenceRequest] = OrderedDict() - self.waiting_request_pool: Dict[int, InferenceRequest] = OrderedDict() - self.completed_request_pool: Dict[int, InferenceRequest] = OrderedDict() - self.request_counter = Counter() - - def add_request( - self, - prompt: str, - prompt_tokens: torch.Tensor, - encoder_prompt: str = None, - inference_parameters: SamplingParams = None, - arrival_time: float = None, - ): - """Add an incoming request - - This method will add the request to either the active pool or the waiting pool - depending on the batch size. - - Args: - prompt (str): Input prompt string - prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized - encoder_prompt (str): Encoder input string - inference_parameters (SamplingParams): The inference parameters - arrival_time (float, optional): The incoming request time. Defaults to None. - """ - request_id = str(next(self.request_counter)) - - if arrival_time is None: - arrival_time = time.time() - - status = ( - Status.ACTIVE_BUT_NOT_GENERATING_TOKENS - if len(self.active_request_pool) < self.max_batch_size - else Status.WAITING_IN_QUEUE - ) - - inference_request = InferenceRequest( - request_id=request_id, - prompt=prompt, - inference_parameters=inference_parameters, - arrival_time=arrival_time, - prompt_tokens=prompt_tokens, - status=status, - encoder_prompt=encoder_prompt, - ) - - if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS: - self.active_request_pool[request_id] = inference_request - else: - self.waiting_request_pool[request_id] = inference_request - - def have_requests_pending(self) -> bool: - """Method to check if there are requests pending - - This method returns False only when there are no active requests or waiting requests. - """ - num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool) - return num_requests_pending > 0 - - def add_earliest_waiting_request_to_active_pool(self): - """Utility to add the waiting request to active pool - - This method will add the earliest request (FIFO) that is in the waiting request - pool to the active request pool. - """ - assert ( - len(self.active_request_pool) < self.max_batch_size - ), "Active request pool is already full. Cant add any more requests" - if len(self.waiting_request_pool) > 0: - (earliest_waiting_request_request_id, earliest_waiting_request) = ( - self.waiting_request_pool.popitem(last=False) - ) - earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS - self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request - - def update_requests_pools(self, result_dict: typing.OrderedDict[int, InferenceRequest] = None): - """Update request pool status - - This method will full up the active request pool, if it has less than max batch size - elements from the waiting request pool. - If provided with a request dict, it will put the completed requests into the completed - request pool and add waiting request into active pool. - - Args: - result (typing.OrderedDict[int, InferenceRequest], optional): The result returned - by the engine. A dictionary with keys as the request ids, and values as the - requests. Defaults to None - """ - for result_request_id in list(result_dict.keys()): - active_request = self.active_request_pool[result_request_id] - - # If a request has completed put it into the completed request pool. 
- if active_request.status == Status.COMPLETED: - completed_request = self.active_request_pool.pop(result_request_id) - self.completed_request_pool[result_request_id] = completed_request - - # If the active request pool is not full, add waiting requests in FIFO order - while ( - len(self.active_request_pool) < self.max_batch_size - and len(self.waiting_request_pool) > 0 - ): - self.add_earliest_waiting_request_to_active_pool() +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import functools +import time +import typing +from collections import OrderedDict +from typing import Dict, Optional, Type, Union + +import torch + +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.utils import Counter + + +class Scheduler: + """Scheduler for handling requests to inference engine + + This class is responsible for handing of all the incomign requests + + Args: + max_batch_size (int): The max batch size that we can pass to the + inference engine at a time. + request_type (InferenceRequest): The class to use for instantiating new requests. + """ + + def __init__(self, max_batch_size): + self.max_batch_size = max_batch_size + self.requests: Dict[str, InferenceRequest] = OrderedDict() + self.streams: Dict[str, AsyncStream] = OrderedDict() + self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.request_counter = Counter() + + def get_new_request_id(self) -> str: + """Gets a new request id""" + request_id = str(next(self.request_counter)) + return request_id + + def add_request( + self, + prompt: Optional[str] = None, + prompt_tokens: Optional[torch.Tensor] = None, + encoder_prompt: Optional[str] = None, + inference_parameters: Optional[SamplingParams] = None, + arrival_time: Optional[float] = None, + streaming: bool = False, + inference_request: Optional[InferenceRequest] = None, + ) -> str: + """Add an incoming request + + This method will add the request to either the active pool or the waiting pool + depending on the batch size. + + Args: + prompt (str): Input prompt string + prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized + encoder_prompt (str): Encoder input string + inference_parameters (SamplingParams): The inference parameters + arrival_time (float, optional): The incoming request time. Defaults to None. + streaming (bool, optional): Whether to asynchronously stream tokens for this request. + inference_request (InferenceRequest, optional): A fully constructed request. + Defaults to None. + + Returns: + The request_id for the new request. 
+ """ + status = ( + Status.ACTIVE_BUT_NOT_GENERATING_TOKENS + if len(self.active_request_pool) < self.max_batch_size + else Status.WAITING_IN_QUEUE + ) + + if inference_request is None: + assert prompt is not None + assert prompt_tokens is not None + + request_id = self.get_new_request_id() + + if arrival_time is None: + arrival_time = time.time() + + inference_request = InferenceRequest( + request_id=request_id, + prompt=prompt, + inference_parameters=inference_parameters, + arrival_time=arrival_time, + prompt_tokens=prompt_tokens, + status=status, + encoder_prompt=encoder_prompt, + ) + else: + request_id = inference_request.request_id + inference_request.status = status + if inference_request.arrival_time is None: + inference_request.arrival_time = time.time() + + self.requests[request_id] = inference_request + + if streaming: + abort_request = functools.partial(self.abort_request, request_id=request_id) + self.streams[request_id] = AsyncStream(request_id, abort_request) + + if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS: + self.active_request_pool[request_id] = inference_request + else: + self.waiting_request_pool[request_id] = inference_request + + return request_id + + def have_requests_pending(self) -> bool: + """Method to check if there are requests pending + + This method returns False only when there are no active requests or waiting requests. + """ + num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool) + return num_requests_pending > 0 + + def add_earliest_waiting_request_to_active_pool(self): + """Utility to add the waiting request to active pool + + This method will add the earliest request (FIFO) that is in the waiting request + pool to the active request pool. + """ + assert ( + len(self.active_request_pool) < self.max_batch_size + ), "Active request pool is already full. Cant add any more requests" + if len(self.waiting_request_pool) > 0: + (earliest_waiting_request_request_id, earliest_waiting_request) = ( + self.waiting_request_pool.popitem(last=False) + ) + earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS + self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request + + def update_requests_pools( + self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None + ): + """Update request pool status + + This method will full up the active request pool, if it has less than max batch size + elements from the waiting request pool. + If provided with a request dict, it will put the completed requests into the completed + request pool and add waiting request into active pool. + + Args: + result (typing.OrderedDict[str, InferenceRequest], optional): The result returned + by the engine. A dictionary with keys as the request ids, and values as the + requests. Defaults to None + """ + for result_request_id in list(result_dict.keys()): + active_request = self.active_request_pool[result_request_id] + + # If a request has completed put it into the completed request pool. 
+ if active_request.status == Status.COMPLETED: + completed_request = self.active_request_pool.pop(result_request_id) + self.completed_request_pool[result_request_id] = completed_request + + # If the active request pool is not full, add waiting requests in FIFO order + while ( + len(self.active_request_pool) < self.max_batch_size + and len(self.waiting_request_pool) > 0 + ): + self.add_earliest_waiting_request_to_active_pool() + + def abort_request( + self, + request_id: str, + *, + exception: Optional[Union[BaseException, Type[BaseException]]] = None + ): + """Cancels the given request""" + stream = self.streams.get(request_id, None) + if stream is not None: + stream.finish(exception=exception) diff --git a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py index 0c2a41be44a10228dce66541e0f93559e691f288..f50ba900bcbd72ff0336b62509df1501c3cd215d 100644 --- a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py @@ -1,35 +1,38 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from typing import OrderedDict - -import torch - -from megatron.core.inference.inference_request import InferenceRequest -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( - TextGenerationController, -) - - -class EncoderDecoderTextGenerationController(TextGenerationController): - """The text generation controller for encoder-decoder architecture - - This class inherits from TextGenerationController, adding features - relating to encoder input encoder_prompt - - """ - - def prep_model_for_inference( - self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] - ): - """Preparing batch for inference, using respective wrapper's prep_model_for_inference method - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] - active_requests (OrderedDict[int, InferenceRequest]): The input active requests - """ - encoder_prompts = list( - map(lambda request: request.encoder_prompt, active_requests.values()) - ) - - self.inference_wrapped_model.prep_model_for_inference( - prompts_tokens=prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=self.tokenizer - ) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
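The scheduler now hands out request ids, accepts fully constructed `InferenceRequest` objects, and can attach an `AsyncStream` per request. A hedged sketch of the request lifecycle from the caller's side (tokenization and the engine loop are omitted, and `prompt_tokens` is a placeholder tensor):

```python
# Illustrative only: prompt_tokens below is a placeholder, not real tokenization.
import torch

from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler

scheduler = Scheduler(max_batch_size=4)

request_id = scheduler.add_request(
    prompt="Hello world",
    prompt_tokens=torch.tensor([1, 2, 3]),
    inference_parameters=SamplingParams(num_tokens_to_generate=16),
    streaming=True,  # creates an AsyncStream keyed by this request_id
)

assert scheduler.have_requests_pending()

# If the client disconnects, the request's stream can be finished early.
scheduler.abort_request(request_id)
```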
+from typing import Any, Dict, OrderedDict + +import torch + +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class EncoderDecoderTextGenerationController(TextGenerationController): + """The text generation controller for encoder-decoder architecture + + This class inherits from TextGenerationController, adding features + relating to encoder input encoder_prompt + + """ + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ) -> Dict[str, Any]: + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + + Returns: + A dict of the inference input for the current batch. + """ + encoder_prompts = list( + map(lambda request: request.encoder_prompt, active_requests.values()) + ) + + return self.inference_wrapped_model.prep_inference_input( + prompts_tokens, encoder_prompts, tokenizer=self.tokenizer + ) diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py index f97df132493416b0f53c267b0c9088ef7f668a0d..54627c2c1f787c75d371a7048e0919216c26590f 100644 --- a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py @@ -1,5 +1,5 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import - TextGenerationController as SimpleTextGenerationController, -) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import + TextGenerationController as SimpleTextGenerationController, +) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index f15c819c43a5f824224b958e78d2359260a18640..f1a4ae45a6bc10f96c25c7ce664583933c2e05a9 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1,400 +1,674 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
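The encoder-decoder controller's `prep_inference_input` pulls `encoder_prompt` off every active request before delegating to the T5 wrapper. The sketch below shows the shape of the active-request pool it expects; the prompts and tokenizer are placeholders, and the field names follow the `InferenceRequest` construction used by the scheduler above.

```python
# Illustrative only: builds the OrderedDict shape that
# EncoderDecoderTextGenerationController.prep_inference_input expects.
import time
from collections import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams


def build_active_pool(prompts, encoder_prompts, tokenizer):
    pool = OrderedDict()
    for i, (prompt, encoder_prompt) in enumerate(zip(prompts, encoder_prompts)):
        request_id = str(i)
        pool[request_id] = InferenceRequest(
            request_id=request_id,
            prompt=prompt,
            prompt_tokens=torch.tensor(tokenizer.tokenize(prompt)),
            encoder_prompt=encoder_prompt,
            inference_parameters=SamplingParams(num_tokens_to_generate=32),
            arrival_time=time.time(),
            status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
        )
    return pool
```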
-from typing import List, OrderedDict, Tuple - -import torch -import torch.nn.functional as F - -from megatron.core import parallel_state -from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( - AbstractModelInferenceWrapper, -) -from megatron.core.inference.sampling_params import SamplingParams - - -class TextGenerationController: - """The text generation controller (the main sampling loop) - - This class tokenizes the input, runs inference, samples from logits, and detokenizes the output. - - Args: - inference_wrapped_model (AbstractModelInferenceWrapper): A model that - is wrapped using the specs given in the abstract_model_inference_wrapper.py - tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts - """ - - def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer): - self.inference_wrapped_model = inference_wrapped_model - self.tokenizer = tokenizer - - # For models without pipeline parallelism, is_first_stage and is_last_stage returns True - self.model_is_pipeline_parallel = not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() - ) - - def tokenize_prompt( - self, prompt: str, add_BOS: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Utility to tokenize the input prompts - - Args: - prompt (str): The input prompt - - Returns: - torch.Tensor: Returns the tokenized prompt - """ - prompt_tokens = self.tokenizer.tokenize(prompt) - - if add_BOS: - prompt_tokens = [self.tokenizer.bos] + prompt_tokens - - return prompt_tokens - - def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str: - """Detokenize the output generations - - Args: - prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt - tokens plus the generated tokens - - Returns: - str: The detokenized output - """ - tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist() - return self.tokenizer.detokenize(tokens) - - def sample_from_logits( - self, - last_token_logits: torch.Tensor, - sampling_params: SamplingParams = None, - vocab_size: int = None, - **kwargs - ) -> torch.Tensor: - """Samples the logits to generate outputs - - Given the logits of the last token, this function samples it - according to the parameters defined in sampling_params - and returns the samples - - Args: - last_token_logits (torch.Tensor): The last token logits. A tensor of - size [batch_size, vocab_size] - sampling_params (SamplingParams): The parameters to use for inference. - vocab_size (int): Obtained from the tokenizer. 
Defaults to None - - Returns: - torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements - """ - - if kwargs.get('common_inference_params'): - sampling_params = kwargs['common_inference_params'] - - top_p = sampling_params.top_p - top_k = sampling_params.top_k - temperature = sampling_params.temperature - - assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero' - assert top_p <= 1.0, 'top-p should be in (0,1]' - - def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf.""" - filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(filter_, float('-Inf')) - - def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf.""" - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - - # Filteration based on the cumulative sum. - filter_ = cumulative_probs > top_p - # This shift by 1 is weird and I cannot justify it. This existed - # in the original implementation: - # https://github.com/ari-holtzman/degen/blob/master/gen.py - # and I guess it is needed so keeping it for now. - filter_[:, 1:] = filter_[:, :-1].clone() - # Make sure we at least have one token to select from. - filter_[..., 0] = 0 - - # Fill in the filtered part - filter_ = filter_.scatter(1, sorted_indices, filter_) - logits.masked_fill_(filter_, float('-Inf')) - - # Greedy sampling - if top_k == 1: - sampled_logits = torch.argmax(last_token_logits, dim=-1) - else: - last_token_logits = last_token_logits.clone() - if temperature != 1.0: - last_token_logits.div_(temperature) - - if top_k > 1: - assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.' - if vocab_size: - assert top_k < vocab_size, 'top-k is larger than vocab size.' - modify_logits_for_top_k_filtering(last_token_logits, top_k) - - elif top_p > 0.0: - modify_logits_for_top_p_filtering(last_token_logits, top_p) - - # After filtering, we need to recalculate the distribution. - probabilities = last_token_logits.softmax(dim=-1) - sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1) - - # If vocab size is provided, make sure the samples are in in the range [0, vocab-size). - if vocab_size: - sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1)) - return sampled_logits - - def update_generation_status( - self, - updated_prompts_tokens: torch.Tensor, - generation_started: torch.Tensor, - current_context_end_position: int, - is_generation_done_tensor: torch.Tensor, - generated_sequence_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Checks which prompts have reached an end condition - - We check which prompts have reached an end condition and set the corresponding - flags of the is_generation_done_tensor to True. The generated sequence lengths - increase as we keep generating, until that prompts hits an end condition. The - generation_started tensor determines which prompts have started generating. - - Args: - updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest - generated tokens. A tensor of shape [batch_size, max_seq_len] - (i.e max_seq_len = max_prompt_len + tokens_to_generate) - generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True - indicates the prompt at that index has started generating tokens. 
- current_context_end_position (int): An integer indicating which position to - extract from the prompts tokens to get the latest generated tokens. - is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size]. - True indicates the prompt at that index has reached end condition. - generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size]. - Each value represents the generated sequence lengths for that prompt. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean - is_generation_done_tensor and the generated_sequence_lengths after updating it - """ - latest_samples = updated_prompts_tokens[:, current_context_end_position] - # Make sure we are checking eod criterion only for prompts that have started generating - # (i.e) We only look at the generated tokenns and not the input tokens. - reached_eod = (latest_samples == self.tokenizer.eod) & generation_started - is_generation_done_tensor = is_generation_done_tensor | reached_eod - # We increment generated sequence lengths when that prompt has not hit the - # EOD and generation has started - generated_sequence_lengths += ~is_generation_done_tensor & generation_started - - return is_generation_done_tensor, generated_sequence_lengths - - def pad_input_prompt_tokens( - self, - batch_prompt_tokens_list: List[List[int]], - max_prompt_length_in_batch: int, - num_tokens_to_generate: int, - ) -> torch.Tensor: - """Method to pad input prompts - - Given a list of prompts, pad them all to uniform length - - Args: - batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens - max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens - num_tokens_togenerate (int): The number of tokens to generate for each prompt - - Returns: - torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e) - max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate, - with extra indices for each tensor padded with mask id. - """ - max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate - - for prompt_tokens in batch_prompt_tokens_list: - padding_size = max_seq_len - len(prompt_tokens) - prompt_tokens.extend([self.tokenizer.eod] * padding_size) - - return torch.tensor(batch_prompt_tokens_list).cuda() - - def generate_output_tokens_dynamic_batch( - self, active_requests: OrderedDict[int, InferenceRequest] - ) -> OrderedDict[int, InferenceRequest]: - """Utility to generate the output tokens and probabilities for the prompts - - This utility generates the output tokens for a dynamic batch. It will run one forward step - at a time, and pass control back to the engine, which will update the request pool and call - this method again. - - Args: - active_requests (OrderedDict[int, InferenceRequest]): The input active requests. - - Returns: - OrderedDict[int, InferenceRequest]: The result for each of the incoming requests - after running one forward step. - """ - raise Exception("Not implemented yet") - - def generate_all_output_tokens_static_batch( - self, active_requests: OrderedDict[int, InferenceRequest] - ) -> OrderedDict[int, InferenceRequest]: - """Utility to generate the all the output tokens and probabilities for the prompts . - - This utility generates the output tokens for a static batch. 
It runs the forward steps till - all prompts complete generation, updates the status of these requests to completed, adds - the generated result and returns these requests - - Args: - active_requests (OrderedDict[int, InferenceRequest]): The input active requests. - - Returns: - OrderedDict[int, InferenceRequest]: The result for each of the incoming requests - """ - batch_prompt_tokens_list = list( - map(lambda request: request.prompt_tokens, active_requests.values()) - ) - prompt_lengths_in_batch = torch.tensor( - [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list] - ).cuda() - max_prompt_length_in_batch = max(prompt_lengths_in_batch) - min_prompt_length_in_batch = min(prompt_lengths_in_batch) - - # For batch inference the inference params are the same for all request - sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters - - # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate - batch_prompt_tokens = self.pad_input_prompt_tokens( - batch_prompt_tokens_list, - max_prompt_length_in_batch=max_prompt_length_in_batch, - num_tokens_to_generate=sampling_params.num_tokens_to_generate, - ) - batch_size, max_sequence_length = batch_prompt_tokens.shape - - # Pre allocate log probs tensor - output_log_probs = None - if sampling_params.return_log_probs: - output_log_probs = torch.empty( - (batch_size, max_sequence_length - 1), dtype=torch.float32 - ).cuda() - - # An array to check which of the prompts have reached end of generation condition - is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda() - - # An array to act as a counter to keep track of generated sequence lengths - generated_sequence_lengths = torch.zeros(batch_size).cuda() - - with torch.no_grad(): - - self.prep_model_for_inference( - prompts_tokens=batch_prompt_tokens, active_requests=active_requests - ) - - context_start_position = 0 - # Pick the context window that we need to pass through the network. - for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): - - inference_input = self.inference_wrapped_model.get_batch_for_context_window( - context_start_position, context_end_position - ) - - # Returns the final logits of shape [batch_size, context_length, vocab_size] - # Note: This is returned in all TP ranks or last PP stage in PP models - logits = self.inference_wrapped_model.run_one_forward_step(inference_input) - if self.model_is_pipeline_parallel: - context_length = context_end_position - context_start_position - logits = broadcast_from_last_pipeline_stage( - [batch_size, context_length, self.tokenizer.vocab_size], - dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, - tensor=logits, - ) - - # Indicates which of the input prompts have started generating tokens. 
- # A 1D boolean tensor with [batch_size] elements (i.e) The shortest - # prompts will start generating first and so on - generation_started = prompt_lengths_in_batch <= context_end_position - last_token_logits = logits[:, -1, :] - sampled_logits = self.sample_from_logits( - last_token_logits, sampling_params, self.tokenizer.vocab_size - ) - - # Substitute the sampled logits only for only the prompts that - # have started generating tokens - batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[ - generation_started - ] - - if sampling_params.return_log_probs: - log_probs = F.log_softmax(logits, dim=2) - indices = torch.unsqueeze( - batch_prompt_tokens[ - :, (context_start_position + 1) : (context_end_position + 1) - ], - 2, - ) - # Get the log probabilities for only the prompt tokens - output_log_probs[:, context_start_position:context_end_position] = torch.gather( - log_probs, 2, indices - ).squeeze(2) - - context_start_position = context_end_position - - # Check end of generation status for each tensor - # and update generated sequence lengths - (is_generation_done_tensor, generated_sequence_lengths) = ( - self.update_generation_status( - updated_prompts_tokens=batch_prompt_tokens, - generation_started=generation_started, - current_context_end_position=context_end_position, - is_generation_done_tensor=is_generation_done_tensor, - generated_sequence_lengths=generated_sequence_lengths, - ) - ) - # Boolean flag indicating if all prompts are finished - all_prompts_done = torch.all(is_generation_done_tensor) - if all_prompts_done: - break - - # Include all the generated tokens - batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)] - if sampling_params.return_log_probs: - output_log_probs = output_log_probs[:, :context_end_position] - - generated_sequence_lengths[ - generated_sequence_lengths > sampling_params.num_tokens_to_generate - ] = sampling_params.num_tokens_to_generate - - for idx, request in enumerate(active_requests.values()): - input_prompt_length = int(prompt_lengths_in_batch[idx]) - # Shorter prompts might have generated more than required tokens. So we trim them down - required_sequence_length = int( - min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate) - ) - # Extract only the generated tokens - required_result_tokens = batch_prompt_tokens_with_generations[ - idx, input_prompt_length : (input_prompt_length + required_sequence_length) - ] - - request.generated_length = required_sequence_length - request.generated_tokens = required_result_tokens - request.generated_log_probs = ( - None - if output_log_probs is None - else output_log_probs[idx, input_prompt_length:required_sequence_length] - ) - request.status = Status.COMPLETED - request.generated_text = self.detokenize_generations(required_result_tokens) - - return active_requests - - def prep_model_for_inference( - self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] - ): - """Preparing batch for inference, using respective wrapper's prep_model_for_inference method - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] - active_requests (OrderedDict[int, InferenceRequest]): The input active requests - """ - self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+import concurrent +import copy +import functools +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union + +import torch +import torch.nn.functional as F + +from megatron.core import parallel_state +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage +from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.transformer.cuda_graphs import create_cudagraphs +from megatron.core.utils import get_model_config + + +class TextGenerationController: + """The text generation controller (the main sampling loop) + + This class tokenizes the input, runs inference, samples from logits, and detokenizes the output. + + Args: + inference_wrapped_model (AbstractModelInferenceWrapper): A model that + is wrapped using the specs given in the abstract_model_inference_wrapper.py + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts + """ + + def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer): + self.inference_wrapped_model = inference_wrapped_model + self.tokenizer = tokenizer + + # For models without pipeline parallelism, is_first_stage and is_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + def tokenize_prompt( + self, prompt: str, add_BOS: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility to tokenize the input prompts + + Args: + prompt (str): The input prompt + + Returns: + torch.Tensor: Returns the tokenized prompt + """ + prompt_tokens = self.tokenizer.tokenize(prompt) + + if add_BOS: + prompt_tokens = [self.tokenizer.bos] + prompt_tokens + + return prompt_tokens + + def detokenize_generations( + self, + tokens_gpu_tensor: torch.Tensor, + lengths_gpu_tensor: torch.Tensor, + detokenize_segments: bool, + ) -> tuple[str, Optional[List[List[str]]]]: + """Detokenize the generated tokens. + + Args: + tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens + lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence + detokenize_segments (bool): If True, returns individually detokenized tokens. If False, + returns None as second element. Helpful for understanding per-token boundaries in + generated text. 
+ + Returns: + tuple[str, List[str] | None]: A tuple containing: + - str: The complete detokenized text + - List[str] | None: List of segmented tokens if detokenize_segments is True, else None + """ + # TODO(helenn): Unify with `detokenize_generations` from legacy textgen path + + if not detokenize_segments: + tokens = tokens_gpu_tensor.cpu().numpy().tolist() + return self.tokenizer.detokenize(tokens), None + + prompts_plus_generations: List[str] = [] + prompts_plus_generations_segments: List[List[str]] = [] + + tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0) + tokens = tokens_gpu_tensor.cpu().numpy().tolist() + lengths = lengths_gpu_tensor.cpu().numpy().tolist() + + for sequence_tokens, length in zip(tokens, lengths): + sequence_tokens = sequence_tokens[:length] + detok_str = self.tokenizer.detokenize(sequence_tokens) + prompts_plus_generations.append(detok_str) + offsets = self.tokenizer.offsets(sequence_tokens, detok_str) + words = [ + detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)]) + ] + + prompts_plus_generations_segments.append(words) + + text = self.tokenizer.detokenize(tokens[0]) + + return text, prompts_plus_generations_segments + + def sample_from_logits( + self, + last_token_logits: torch.Tensor, + sampling_params: Optional[SamplingParams] = None, + vocab_size: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + """Samples the logits to generate outputs + + Given the logits of the last token, this function samples it + according to the parameters defined in sampling_params + and returns the samples + + Args: + last_token_logits (torch.Tensor): The last token logits. A tensor of + size [batch_size, vocab_size] + sampling_params (SamplingParams): The parameters to use for inference. + vocab_size (int): Obtained from the tokenizer. Defaults to None + + Returns: + torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements + """ + + if kwargs.get('common_inference_params'): + sampling_params = kwargs['common_inference_params'] + + top_p = sampling_params.top_p + top_k = sampling_params.top_k + temperature = sampling_params.temperature + + assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero' + assert top_p <= 1.0, 'top-p should be in (0,1]' + + def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. 
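+            # For example, with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.8,
+            # the cumulative sums are [0.5, 0.8, 0.95, 1.0], so filter_ starts out as
+            # [False, False, True, True] and the shift above turns it into
+            # [False, False, False, True]: the token that first crosses the threshold is
+            # still kept, and the line below ensures the highest-probability token always
+            # survives the filter.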
+            filter_[..., 0] = 0
+
+            # Fill in the filtered part
+            filter_ = filter_.scatter(1, sorted_indices, filter_)
+            logits.masked_fill_(filter_, float('-Inf'))
+
+        # Greedy sampling
+        if top_k == 1:
+            sampled_logits = torch.argmax(last_token_logits, dim=-1)
+        else:
+            last_token_logits = last_token_logits.clone()
+            if temperature != 1.0:
+                last_token_logits.div_(temperature)
+
+            if top_k > 1:
+                assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
+                if vocab_size:
+                    assert top_k < vocab_size, 'top-k is larger than vocab size.'
+                modify_logits_for_top_k_filtering(last_token_logits, top_k)
+
+            elif top_p > 0.0:
+                modify_logits_for_top_p_filtering(last_token_logits, top_p)
+
+            # After filtering, we need to recalculate the distribution.
+            probabilities = last_token_logits.softmax(dim=-1)
+            sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
+
+        # If vocab size is provided, make sure the samples are in the range [0, vocab_size).
+        if vocab_size:
+            sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
+        return sampled_logits
+
+    def update_generation_status(
+        self,
+        updated_prompts_tokens: torch.Tensor,
+        generation_started: torch.Tensor,
+        current_context_end_position: int,
+        is_generation_done_tensor: torch.Tensor,
+        generated_sequence_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Checks which prompts have reached an end condition
+
+        We check which prompts have reached an end condition and set the corresponding
+        flags of the is_generation_done_tensor to True. The generated sequence lengths
+        increase as we keep generating, until that prompt hits an end condition. The
+        generation_started tensor determines which prompts have started generating.
+
+        Args:
+            updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
+                generated tokens. A tensor of shape [batch_size, max_seq_len]
+                (i.e max_seq_len = max_prompt_len + tokens_to_generate)
+            generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
+                indicates the prompt at that index has started generating tokens.
+            current_context_end_position (int): An integer indicating which position to
+                extract from the prompts tokens to get the latest generated tokens.
+            is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
+                True indicates the prompt at that index has reached the end condition.
+            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
+                Each value represents the generated sequence length for that prompt.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
+            is_generation_done_tensor and the generated_sequence_lengths after updating them.
+        """
+        latest_samples = updated_prompts_tokens[:, current_context_end_position]
+        # Make sure we are checking the EOD criterion only for prompts that have started
+        # generating, i.e. we only look at the generated tokens and not the input tokens.
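+        # Note that `generation_started` is derived from the prompt lengths in the caller
+        # (generation_started = prompt_lengths_in_batch <= context_end_position), so a longer
+        # prompt that is still being prefilled is never marked done here, even if the token
+        # at this position happens to equal the EOD id.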
+        reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
+        is_generation_done_tensor = is_generation_done_tensor | reached_eod
+        # We increment the generated sequence length only for prompts that have started
+        # generating and have not yet hit EOD
+        generated_sequence_lengths += ~is_generation_done_tensor & generation_started
+
+        return is_generation_done_tensor, generated_sequence_lengths.int()
+
+    def pad_input_prompt_tokens(
+        self,
+        batch_prompt_tokens_list: List[List[int]],
+        max_prompt_length_in_batch: int,
+        num_tokens_to_generate: int,
+    ) -> torch.Tensor:
+        """Method to pad input prompts
+
+        Given a list of prompts, pad them all to a uniform length
+
+        Args:
+            batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
+            max_prompt_length_in_batch (int): Maximum length of the input prompt tokens
+                in the batch
+            num_tokens_to_generate (int): The number of tokens to generate for each prompt
+
+        Returns:
+            torch.Tensor: A torch tensor of shape [bs, max_seq_len], where
+            max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate.
+        """
+        max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
+
+        for prompt_tokens in batch_prompt_tokens_list:
+            padding_size = max_seq_len - len(prompt_tokens)
+            prompt_tokens.extend([self.tokenizer.eod] * padding_size)
+
+        return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())
+
+    def generate_output_tokens_dynamic_batch(
+        self, active_requests: OrderedDict[str, InferenceRequest]
+    ) -> OrderedDict[str, InferenceRequest]:
+        """Utility to generate the output tokens and probabilities for the prompts
+
+        This utility generates the output tokens for a dynamic batch. It will run one forward step
+        at a time, and pass control back to the engine, which will update the request pool and call
+        this method again.
+
+        Args:
+            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
+
+        Returns:
+            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
+            after running one forward step.
+        """
+        raise Exception("Not implemented yet")
+
+    def generate_all_output_tokens_static_batch(
+        self,
+        active_requests: OrderedDict[str, InferenceRequest],
+        active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
+    ) -> OrderedDict[str, InferenceRequest]:
+        """Utility to generate all the output tokens and probabilities for the prompts.
+
+        This utility generates the output tokens for a static batch. It runs the forward steps till
+        all prompts complete generation, updates the status of these requests to completed, adds
+        the generated results, and returns the requests.
+
+        Args:
+            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
+            active_streams (Optional[OrderedDict[str, AsyncStream]]): The streams, if any, over
+                which generated tokens are returned asynchronously.
+
+        Returns:
+            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
+        """
+        assert all(request.prompt_tokens is not None for request in active_requests.values())
+
+        # Perform a deep copy so that the request prompt tokens do not get modified.
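+        # `pad_input_prompt_tokens` below extends these token lists in place with EOD padding,
+        # so working on copies keeps each request's original `prompt_tokens` intact.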
+ batch_prompt_tokens_list: List[List[int]] = list( + map( + lambda request: copy.deepcopy(request.prompt_tokens), # type: ignore[arg-type] + active_requests.values(), + ) + ) + prompt_lengths_in_batch = torch.tensor( + [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list], + device=torch.cuda.current_device(), + ) + max_prompt_length_in_batch = max(prompt_lengths_in_batch) + min_prompt_length_in_batch = min(prompt_lengths_in_batch) + + # For batch inference the inference params are the same for all request + sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters + + # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + batch_prompt_tokens = self.pad_input_prompt_tokens( + batch_prompt_tokens_list, + max_prompt_length_in_batch=max_prompt_length_in_batch, + num_tokens_to_generate=sampling_params.num_tokens_to_generate, + ) + batch_size, max_sequence_length = batch_prompt_tokens.shape + + # Verify that output sequence length is within configured limit + # TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged + inference_max_sequence_length = ( + self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length + ) + assert max_sequence_length <= inference_max_sequence_length, ( + f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens " + f"but requested generation of {max_sequence_length} tokens" + ) + + # Pre allocate log probs tensor + output_log_probs = None + if sampling_params.return_log_probs: + output_log_probs = torch.empty( + (batch_size, max_sequence_length - 1), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + + # An array to check which of the prompts have reached end of generation condition + is_generation_done_tensor = torch.zeros( + batch_size, dtype=torch.bool, device=torch.cuda.current_device() + ) + + # An array to act as a counter to keep track of generated sequence lengths + generated_sequence_lengths = torch.zeros( + batch_size, device=torch.cuda.current_device() + ).cuda() + + # Use padded vocab size because tokenizer vocab size might not include padding + # to nearest power of 2 + vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size + + # Check whether CUDA graphs are enabled + enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph + + streaming_enabled = active_streams is not None and len(active_streams) > 0 + if streaming_enabled: + # Start a separate thread for streaming tokens to avoid blocking the + # main computation + streaming_idx: List[int] = [ + i + for (i, request_id) in enumerate(active_requests.keys()) + if request_id in active_streams + ] + streaming_request_ids: List[str] = list(active_streams.keys()) + streams: List[AsyncStream] = list(active_streams.values()) + streaming_requests: List[InferenceRequest] = [ + active_requests[request_id] for request_id in streaming_request_ids + ] + streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + stream_tokens = functools.partial(self.stream_tokens, sampling_params) + + with torch.no_grad(): + + self.inference_wrapped_model.prep_model_for_inference( + prompts_tokens=batch_prompt_tokens + ) + + inference_input: Dict[str, Any] = self.prep_inference_input( + prompts_tokens=batch_prompt_tokens, active_requests=active_requests + ) + + assert ( + not self.inference_wrapped_model.inference_params.decode_mode + ), f"Generation must start in prefill mode" + + context_start_position = 0 + # Pick the 
context window that we need to pass through the network. + for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): + + inference_input_for_context_window: Dict[str, Any] = ( + self.inference_wrapped_model.get_batch_for_context_window( + inference_input, context_start_position, context_end_position + ) + ) + + # Disable attention mask when using CUDA graphs for decode + if ( + enable_cuda_graph + and self.inference_wrapped_model.inference_params.decode_mode + and "attention_mask" in inference_input_for_context_window + ): + inference_input_for_context_window["attention_mask"] = None + + # Returns the final logits of shape [batch_size, context_length, vocab_size] + # Note: This is returned in all TP ranks or last PP stage in PP models + logits = self.inference_wrapped_model.run_one_forward_step( + inference_input_for_context_window + ) + + if enable_cuda_graph: + create_cudagraphs() + + if self.model_is_pipeline_parallel: + context_length = context_end_position - context_start_position + logits = broadcast_from_last_pipeline_stage( + [batch_size, context_length, vocab_size], + dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, + tensor=logits, + ) + + # Indicates which of the input prompts have started generating tokens. + # A 1D boolean tensor with [batch_size] elements (i.e) The shortest + # prompts will start generating first and so on + generation_started = prompt_lengths_in_batch <= context_end_position + last_token_logits = logits[:, -1, :] + sampled_logits = self.sample_from_logits( + last_token_logits, sampling_params, vocab_size + ) + + # Substitute the sampled logits only for the prompts that + # have started generating tokens + batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[ + generation_started + ] + + if sampling_params.return_log_probs: + log_probs = F.log_softmax(logits, dim=2) + indices = torch.unsqueeze( + batch_prompt_tokens[ + :, (context_start_position + 1) : (context_end_position + 1) + ], + 2, + ) + # Get the log probabilities for only the prompt tokens + assert output_log_probs is not None + output_log_probs[:, context_start_position:context_end_position] = torch.gather( + log_probs, 2, indices + ).squeeze(2) + + context_start_position = context_end_position + + # Check end of generation status for each tensor + # and update generated sequence lengths + (is_generation_done_tensor, generated_sequence_lengths) = ( + self.update_generation_status( + updated_prompts_tokens=batch_prompt_tokens, + generation_started=generation_started, + current_context_end_position=context_end_position, + is_generation_done_tensor=is_generation_done_tensor, + generated_sequence_lengths=generated_sequence_lengths, + ) + ) + + # Stream intermediate outputs + if streaming_enabled: + streaming_executor.submit( + stream_tokens, + streaming_request_ids, + streaming_requests, + streams, + generation_started[streaming_idx].cpu(), + is_generation_done_tensor[streaming_idx].cpu(), + batch_prompt_tokens[streaming_idx].cpu(), + prompt_lengths_in_batch[streaming_idx].cpu(), + generated_sequence_lengths[streaming_idx].cpu(), + ( + output_log_probs[streaming_idx].cpu() + if output_log_probs is not None + else [None] * len(streaming_idx) + ), + ) + + # Boolean flag indicating if all prompts are finished + all_prompts_done = torch.all(is_generation_done_tensor) + if all_prompts_done: + break + + # Change to decode mode if all prefill is complete + if torch.all(generation_started): + 
self.inference_wrapped_model.inference_params.enable_decode_mode() + + # Close all streams + if streaming_enabled: + streaming_executor.shutdown() + for stream in streams: + stream.finish() + + # Include all the generated tokens + batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)] + if sampling_params.return_log_probs: + assert output_log_probs is not None + output_log_probs = output_log_probs[:, :context_end_position] + + generated_sequence_lengths[ + generated_sequence_lengths > sampling_params.num_tokens_to_generate + ] = sampling_params.num_tokens_to_generate + + for idx, request in enumerate(active_requests.values()): + input_prompt_length = int(prompt_lengths_in_batch[idx]) + # Shorter prompts might have generated more than required tokens. So we trim them down + required_sequence_length = int( + min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate) + ) + # Extract only the generated tokens + required_result_tokens = batch_prompt_tokens_with_generations[ + idx, input_prompt_length : (input_prompt_length + required_sequence_length) + ] + generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32) + request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32) + request.generated_length = required_sequence_length + request.generated_tokens = required_result_tokens + + request.prompt_log_probs = ( + None + if output_log_probs is None + else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist() + ) + + request.generated_log_probs = ( + None + if output_log_probs is None + else output_log_probs[ + idx, + input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1), + ] + .cpu() + .numpy() + .tolist() + ) + request.status = Status.COMPLETED + + text, segments = self.detokenize_generations( + batch_prompt_tokens_with_generations[idx], + input_prompt_length + generated_sequence_lengths, + sampling_params.return_segments, + ) + request.text = text # Inference server returns prompts & generations together + if sampling_params.return_segments: + request.segments = segments[0] + request.generated_text = text[len(request.prompt) :] + return active_requests + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ) -> Dict[str, Any]: + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + + Returns: + A dict of the inference input for the current batch. + """ + return self.inference_wrapped_model.prep_inference_input(prompts_tokens) + + def stream_tokens( + self, + sampling_params: SamplingParams, + request_ids: List[str], + requests: List[InferenceRequest], + streams: List[AsyncStream], + generation_started: List[bool], + is_generation_done: List[bool], + tokens: torch.Tensor, + prompt_lengths: List[int], + generated_lengths: List[int], + output_log_probs: Union[torch.Tensor, None], + ): + """Asynchronously streams tokens for the given requests. + + Args: + sampling_params (SamplingParams): The sampling parameters. + request_ids (List[str]): The request IDs. + request (List[InferenceRequest]): The requests. + stream (List[AsyncStream]): The streams over which to send tokens. + generation_started (List[bool]): Whether the decode step has started. 
+ is_generation_done (List[bool]): Whether generation has completed. + tokens (torch.Tensor): The tokens for this request. + prompt_lengths (List[int]): The number of prompt tokens for each request. + generated_lengths (List[int]): The number of output tokens for each request. + output_log_probs (torch.Tensor, optional): The log probs for each request. + """ + + def stream_token( + request_id: str, + request: InferenceRequest, + stream: AsyncStream, + generation_started: bool, + is_generation_done: bool, + tokens: torch.Tensor, + prompt_length: int, + generated_length: int, + output_log_probs: Union[torch.Tensor, None], + ): + """Asynchronously streams a token for the given request.""" + + if not generation_started or stream.finished: + return + + num_tokens_to_generate = sampling_params.num_tokens_to_generate + return_segments = sampling_params.return_segments + detokenize_streaming_text = not getattr( + sampling_params, "no_detokenize_streaming_text", False + ) + + generated_tokens = tokens[prompt_length : prompt_length + generated_length] + + if detokenize_streaming_text: + generated_text, generated_segments = self.detokenize_generations( + generated_tokens, prompt_length + generated_length, return_segments + ) + else: + generated_text = "" + generated_segments = [] + + if output_log_probs is not None: + generated_log_probs = ( + output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1] + .cpu() + .numpy() + .tolist() + ) + else: + generated_log_probs = None + + stream.put( + InferenceRequest( + request_id=request_id, + prompt=request.prompt, + inference_parameters=request.inference_parameters, + prompt_tokens=request.prompt_tokens, + arrival_time=request.arrival_time, + status=request.status, + encoder_prompt=request.encoder_prompt, + generated_text=generated_text, + generated_segments=generated_segments, + generated_tokens=generated_tokens, + generated_log_probs=generated_log_probs, + generated_length=generated_length, + ) + ) + + if is_generation_done or generated_length == num_tokens_to_generate: + stream.finish() + + ret = map( + stream_token, + request_ids, + requests, + streams, + generation_started, + is_generation_done, + tokens, + prompt_lengths, + generated_lengths, + output_log_probs, + ) + list(ret) diff --git a/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..1d92947adf18d701068c8081e2a38417b7e4b52a --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import OrderedDict + +import torch + +from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class VLMTextGenerationController(TextGenerationController): + """The text generation controller for VLMs""" + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ): + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Currently only supports batch size 1 inference. 
+ + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + """ + assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1" + + request = list(active_requests.values())[0] + + assert isinstance( + request, VLMInferenceRequest + ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest" + + return self.inference_wrapped_model.prep_inference_input( + prompts_tokens, + request.num_img_embeddings_per_tile, + request.imgs, + request.num_tiles, + request.decoder_seq_length, + ) diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py index 0db49e3115af66273dfa4052a0929131ab06c679..846ceb73a533303fc5347fd955409cd0f0f409ec 100644 --- a/megatron/core/inference_params.py +++ b/megatron/core/inference_params.py @@ -1,31 +1,100 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference.""" - - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length - self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} - - def swap_key_value_dict(self, batch_idx): - "swap between batches" - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") - - for layer_number in self.key_value_memory_dict.keys(): - inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] - assert ( - len(batch_idx) == inference_key_memory.shape[1] - ) # make sure batch size is the same - new_inference_key_memory = inference_key_memory[:, batch_idx] - new_inference_value_memory = inference_value_memory[:, batch_idx] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, - ) - - def __str__(self): - return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})" +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
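+# Illustrative sketch of how the static-batch generation loop in
+# text_generation_controller.py is expected to drive these flags (the constructor
+# arguments below are placeholder values):
+#
+#     params = InferenceParams(max_batch_size=8, max_sequence_length=4096)
+#     assert not params.decode_mode   # generation must start in prefill mode
+#     params.enable_decode_mode()     # once every prompt in the batch is fully prefilled
+#     params.reset()                  # return to prefill mode before reusing for a new batch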
+class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+
+    def __init__(self, max_batch_size, max_sequence_length):
+        self.max_sequence_length = max_sequence_length
+        self.max_batch_size = max_batch_size
+        self.current_batch_size = max_batch_size  # Required for bookkeeping variable-sized batches
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
+        self.decode_mode = False
+        self.key_value_memory_dict = {}
+
+    def swap_key_value_dict(self, batch_idx):
+        "swap between batches"
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError("should not swap when dict is empty")
+
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+            assert (
+                len(batch_idx) == inference_key_memory.shape[1]
+            )  # make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (
+                new_inference_key_memory,
+                new_inference_value_memory,
+            )
+
+    def enable_prefill_mode(self):
+        """
+        Indicates the generation loop is in the prefill phase (still processing
+        input prompt tokens). This should be enabled if the generation loop is
+        encoding prompt tokens for *any* request in a batch.
+        """
+        self.decode_mode = False
+
+    def enable_decode_mode(self):
+        """
+        Indicates the generation loop is in the decode phase (generating new output
+        tokens). This should only be enabled if the generation loop has fully encoded
+        the prompts for *all* requests in a batch.
+        """
+        self.decode_mode = True
+
+    def reset(self):
+        """Resets the inference state for a new batch."""
+        self.current_batch_size = self.max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
+        self.enable_prefill_mode()
+
+    def __str__(self):
+        return (
+            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
+            f"max_batch_size = {self.max_batch_size}, "
+            f"current_batch_size = {self.current_batch_size}, "
+            f"sequence_len_offset = {self.sequence_len_offset}, "
+            f"batch_size_offset = {self.batch_size_offset}, "
+            f"key_value_memory_dict = {self.key_value_memory_dict.keys()}, "
+            f"decode_mode = {self.decode_mode})"
+        )
+
+    def __eq__(self, other):
+
+        if not isinstance(other, InferenceParams):
+            return False
+
+        # Check all attributes match
+        basic_attrs = [
+            'max_sequence_length',
+            'max_batch_size',
+            'current_batch_size',
+            'sequence_len_offset',
+            'batch_size_offset',
+        ]
+
+        if not all(hasattr(other, attr) for attr in basic_attrs):
+            return False
+
+        # Check dictionary keys match; i.e.
the same number of layers are cached + if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys(): + return False + + # Check each tensor tuple in the dictionary + for key in self.key_value_memory_dict: + self_tensors = self.key_value_memory_dict[key] + other_tensors = other.key_value_memory_dict[key] + + # Compare each key, value tensor in the tuple + for self_tensor, other_tensor in zip(self_tensors, other_tensors): + if ( + self_tensor.data_ptr() != other_tensor.data_ptr() + or self_tensor.shape != other_tensor.shape + ): + return False + return True diff --git a/megatron/core/jit.py b/megatron/core/jit.py index c35c41b9fa226b928e7dc35d5dcec95f2b6a6c2c..5b1dfff3e7786af920e99bff9b3491793e5a0c91 100644 --- a/megatron/core/jit.py +++ b/megatron/core/jit.py @@ -7,18 +7,4 @@ from megatron.core.utils import is_torch_min_version jit_fuser = torch.jit.script # nvFuser is deprecated in PyTorch JIT starting from 2.2 if is_torch_min_version("2.2.0a0"): - jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs') - -# Decorator to disable Torch Dynamo -# See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func -if torch.__version__ >= "2": - import torch._dynamo - - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable( - f, recursive=recursive - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable + jit_fuser = torch.compile diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 46a03f6d6d0f7acbbe75e0d9fe9bcdaf5b4a5cee..cf6b20f8da3e80074e8a3a428a4477fdb98f4eff 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -1,387 +1,392 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from dataclasses import dataclass -from typing import Callable, ContextManager, Optional - -import torch - - -@dataclass -class ModelParallelConfig: - """Base configuration for Megatron Core - - The initialization function has an argument for each parameter. - """ - - ################### - # Model parallelism - ################### - tensor_model_parallel_size: int = 1 - """Intra-layer model parallelism. Splits tensors across GPU ranks.""" - - pipeline_model_parallel_size: int = 1 - """Inter-layer model parallelism. Splits transformer layers across GPU ranks.""" - - virtual_pipeline_model_parallel_size: Optional[int] = None - """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline - bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. - The number of virtual blocks per pipeline model parallel rank is the virtual model parallel - size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: - arxiv.org/pdf/2104.04473.pdf for more details. - """ - - sequence_parallel: bool = False - """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms - and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models - (https://arxiv.org/abs/2205.05198) for more details. - """ - - context_parallel_size: int = 1 - """Splits network input along sequence dimension across GPU ranks.""" - - hierarchical_context_parallel_sizes: Optional[list[int]] = None - """Degrees of the hierarchical context parallelism. 
Users should provide a list to specify - the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains - groups of two levels, so the first value of the list indicates the group size of the a2a - communication type, and the second value indicates the group size of the p2p communication - type. - """ - - expert_model_parallel_size: int = 1 - """Distributes Moe Experts across sub data parallel dimension.""" - - expert_tensor_parallel_size: Optional[int] = None - """Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks.""" - - moe_extended_tp: bool = False - """NOTE: Deprecated from MCore v0.10. This flag is ignored. - Its functionality is replaced by expert_tensor_parallel_size. - """ - - ################### - # Initialization - ################### - perform_initialization: bool = True - """If true, weights are initialized. This option can be useful when you know you are going to - load values from a checkpoint. - """ - - use_cpu_initialization: bool = False - """When set to False, we initialize the weights directly on the GPU. CPU initialization is the - same regardless of tensor model parallelism, but GPU initialization is not. Transferring - weights from CPU to GPU can take a significant amount of time for large models. - """ - - ################### - # Training - ################### - fp16: bool = False - """If true, train with fp16 mixed precision training.""" - - bf16: bool = False - """If true, train with bf16 mixed precision training.""" - - params_dtype: torch.dtype = torch.float32 - """dtype used when intializing the weights.""" - - timers: Optional[Callable] = None - """Timers object to call for various timing functions. See megatron.core.timers.Timers""" - - finalize_model_grads_func: Optional[Callable] = None - """Function that finalizes gradients on all workers. Could include ensuring that grads are - all-reduced across data parallelism, pipeline parallelism, and sequence parallelism - dimensions. - """ - - grad_scale_func: Optional[Callable] = None - """If using loss scaling, this function should take the loss and return the scaled loss. If - None, no function is called on the loss. - """ - - no_sync_func: Optional[Callable] = None - """Function that creates a context that suppresses asynchronous data-parallel communication. If - the model is an instance of core.distributed.DistributedDataParallel, the default is to use - core.distributed.DistributedDataParallel.no_sync. - """ - - grad_sync_func: Optional[Callable] = None - """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient - reduce-scatters). The function should take one argument: an iterable of parameters whose - gradients are to be synchronized. - """ - - param_sync_func: Optional[Callable] = None - """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer - parameter all-gathers). The function should take one argument: an iterable of parameters to - be synchronized. - """ - - deterministic_mode: bool = False - """If true, code that has deterministic execution will be chosen. This usually - means slower execution, but is good for debugging and testing. Defaults to False.""" - - enable_autocast: bool = False - """If true runs the forward step function inside torch.autocast context.""" - - autocast_dtype: Optional[torch.dtype] = None - """dtype to pass to torch.amp.autocast when enabled. 
If None, is set to pipeline_dtype.""" - - num_microbatches_with_partial_activation_checkpoints: Optional[int] = None - """If int, set the number of microbatches where not all of the layers will be checkpointed and - recomputed. The rest of the microbatches within the window of maximum outstanding - microbatches will recompute all layers (either full recompute or selective recompute). If - None, the checkpoint and recompute will be left up to the forward_step function. - - """ - - ################### - # Optimizations - ################### - gradient_accumulation_fusion: bool = False - """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension - fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install - APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" - --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you - must turn off gradient accumulation fusion. - """ - - async_tensor_model_parallel_allreduce: bool = False - """NOTE: Deprecated. This flag is ignored.""" - - use_te_rng_tracker: bool = False - """If true, uses RNG state tracker in TransformerEngine if exists. - """ - - tp_comm_overlap: bool = False - """If true, allows overlapping of Linear layer execution with tensor parallel communication - collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever - possible during the forward and the backward pass. - """ - - tp_comm_bulk_wgrad: bool = True - """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if - tp_comm_overlap is False. - """ - - tp_comm_bulk_dgrad: bool = True - """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if - tp_comm_overlap is False. - """ - - tp_comm_overlap_ag: bool = True - """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather. - Don't care if tp_comm_overlap is False. - """ - - tp_comm_overlap_rs: bool = True - """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter. - Don't care if tp_comm_overlap is False. - """ - - tp_comm_overlap_rs_dgrad: bool = False - """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the - GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False. - """ - - tp_comm_split_ag: bool = True - """Deprecated from TransformerEngine v1.6.0. - If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather - splits. Don't care if tp_comm_overlap is False. - """ - - tp_comm_atomic_ag: bool = False - """Deprecated from TransformerEngine v1.6.0. - If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather - both done atomically. Don't care if tp_comm_overlap is False. - """ - - tp_comm_split_rs: bool = True - """Deprecated from TransformerEngine v1.6.0. - If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and - Reduce-Scatter splits. Don't care if tp_comm_overlap is False. - """ - - tp_comm_atomic_rs: bool = False - """Deprecated from TransformerEngine v1.6.0. - If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and - Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False. - """ - - cross_entropy_loss_fusion: bool = False - """If this is enabled, the fused cross entropy implementation would be used. - Defaults to False. 
- """ - - tp_comm_overlap_disable_qkv: bool = False - """ - If true, the AllGather -> Gemm overlap for QKV gets disabled - """ - - tp_comm_overlap_disable_fc1: bool = False - """ - If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled - """ - - tp_comm_bootstrap_backend: str = 'nccl' - """ - Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo' - """ - - ################### - # Pipeline Parallel - ################### - pipeline_dtype: torch.dtype = None - """dtype used in p2p communication, usually params_dtype""" - - variable_seq_lengths: bool = False - """Support for variable sequence lengths across microbatches. Setting this communicates the size - of tensors during pipeline parallelism communication, because of this extra overhead it - should only be set if the sequence length varies by microbatch within a global batch. - """ - - overlap_p2p_comm: bool = False - """When True some of the peer to peer communication for pipeline parallelism will overlap with - computation. Must be False if batch_p2p_comm is true. - """ - - batch_p2p_comm: bool = True - """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if - overlap_p2p_comm is True. - """ - - batch_p2p_sync: bool = True - """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in - older version of PyTorch. - """ - - use_ring_exchange_p2p: bool = False - """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires - custom built torch with torch.distributed.ring_exchange. - """ - - deallocate_pipeline_outputs: bool = False - """If True, output data is deallocated after the tensor is sent to the next pipeline stage. - Helps with saving memory, does nothing when pipeline parallel is not used. - """ - - defer_embedding_wgrad_compute: bool = False - """If true, defers the embedding WGRAD GEMMs while pipeline flush is - taking place enabling us to hide pipeline flush latency. Defaults to False. - """ - - wgrad_deferral_limit: int = 0 - """This value tunes the number of micro-batches for which the embedding weight gradient compute - needs to be deferred to pipeline flush, this argument is invalid if - `defer_embedding_wgrad_compute` is False. - Defaults to 0, which means all micro-batches are deferred. - """ - - pipeline_model_parallel_split_rank: Optional[int] = None - """If int, rank where encoder and decoder should be split in cases where the model has both an - encoder and decoder (e.g., T5). Ignored if None. - """ - - overlap_p2p_comm_warmup_flush: bool = False - """If true, overlap communication and computation in warm up and flush phase. - Only valid when overlap_p2p_comm is True and batch_p2p_comm is False. - Defaults to False. - """ - - microbatch_group_size_per_vp_stage: Optional[int] = None - """This value specifies the number of micro-batches that are executed - at a time for a given virtual stage (both forward and backward). - Default (in __post_init__() method below) to pipeline_parallel_size - which specifies a depth-first schedule. 
- Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2, - num_microbatches = 4, we have - rank 0 | 0 1 0 1 2 3 2 3 - rank 1 | 0 1 0 1 2 3 2 3 - When microbatch_group_size_per_vp_stage=3, num_microbatches = 5, - we have - rank 0 | 0 1 2 0 1 2 3 4 3 4 - rank 1 | 0 1 2 0 1 2 3 4 3 4 - """ - - ################### - # CPU Offloading - ################### - cpu_offloading: bool = False - """When set to True, all the activations are offloaded to the CPU asynchronously.""" - - cpu_offloading_num_layers: int = 0 - """Tells the number of transformer layers for which activations has to be offloaded.""" - - _cpu_offloading_context: Optional[ContextManager] = ( - None - # Used for internal use only, not to be set by a user. - # TODO: Need to move to the 'right' place when possible. - ) - """For internal use only, do not set.""" - - cpu_offloading_activations: bool = True - """If True, offloads the activations to CPU.""" - - cpu_offloading_weights: bool = True - """If True, offloads the weights to CPU.""" - - ################### - # Timing - ################### - barrier_with_L1_time: bool = True - """If true, use barrier with level 1 time measurements. It is up to the user to make sure - calling barrier with their timers will not result in hangs. This can happen if for example - the user adds a level 1 timer that is not called by all ranks. - """ - - def __post_init__(self): - """Python dataclass method that is used to modify attributes after initialization. - See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more - details. - """ - if self.sequence_parallel: - if self.tensor_model_parallel_size <= 1: - raise ValueError("Can not use sequence paralllelism without tensor parallelism") - - if self.expert_tensor_parallel_size is None: - self.expert_tensor_parallel_size = self.tensor_model_parallel_size - - if self.pipeline_model_parallel_size > 1: - if self.pipeline_dtype is None: - raise ValueError( - "When using pipeline parallelism, pipeline_dtype must be specified" - ) - - if self.autocast_dtype is None: - self.autocast_dtype = self.params_dtype - - if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1: - raise ValueError( - "Cannot defer embedding wgrad compute when pipeline model parallel is not used" - ) - - if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion: - raise ValueError( - "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used" - ) - - if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0: - raise ValueError( - "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!" - ) - - if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: - if self.sequence_parallel is False: - raise ValueError( - "When using expert parallelism and tensor parallelism, " - "sequence parallelism must be used" - ) - - if self.microbatch_group_size_per_vp_stage is None: - self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size - - if self.overlap_p2p_comm_warmup_flush: - if not self.overlap_p2p_comm or self.batch_p2p_comm: - raise ValueError( - "Pipeline parallel communication overlapping in warmup and flush is only " - "compatible with overlap_p2p_comm but not batch_p2p_comm." - ) +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +from dataclasses import dataclass +from typing import Callable, ContextManager, Optional + +import torch + + +@dataclass +class ModelParallelConfig: + """Base configuration for Megatron Core + + The initialization function has an argument for each parameter. + """ + + ################### + # Model parallelism + ################### + tensor_model_parallel_size: int = 1 + """Intra-layer model parallelism. Splits tensors across GPU ranks.""" + + pipeline_model_parallel_comm_backend: Optional[str] = None + """Configuring backend option of pipeline parallel communication (e.g., nccl, ucc) + If None, the default backend will be used. + """ + + pipeline_model_parallel_size: int = 1 + """Inter-layer model parallelism. Splits transformer layers across GPU ranks.""" + + virtual_pipeline_model_parallel_size: Optional[int] = None + """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline + bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. + The number of virtual blocks per pipeline model parallel rank is the virtual model parallel + size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: + arxiv.org/pdf/2104.04473.pdf for more details. + """ + + sequence_parallel: bool = False + """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms + and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models + (https://arxiv.org/abs/2205.05198) for more details. + """ + + context_parallel_size: int = 1 + """Splits network input along sequence dimension across GPU ranks.""" + + hierarchical_context_parallel_sizes: Optional[list[int]] = None + """Degrees of the hierarchical context parallelism. Users should provide a list to specify + the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains + groups of two levels, so the first value of the list indicates the group size of the a2a + communication type, and the second value indicates the group size of the p2p communication + type. + """ + + expert_model_parallel_size: int = 1 + """Distributes Moe Experts across sub data parallel dimension.""" + + expert_tensor_parallel_size: Optional[int] = None + """Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks.""" + + moe_extended_tp: bool = False + """NOTE: Deprecated from MCore v0.10. This flag is ignored. + Its functionality is replaced by expert_tensor_parallel_size. + """ + + ################### + # Initialization + ################### + perform_initialization: bool = True + """If true, weights are initialized. This option can be useful when you know you are going to + load values from a checkpoint. + """ + + use_cpu_initialization: bool = False + """When set to False, we initialize the weights directly on the GPU. CPU initialization is the + same regardless of tensor model parallelism, but GPU initialization is not. Transferring + weights from CPU to GPU can take a significant amount of time for large models. + """ + + ################### + # Training + ################### + fp16: bool = False + """If true, train with fp16 mixed precision training.""" + + bf16: bool = False + """If true, train with bf16 mixed precision training.""" + + params_dtype: torch.dtype = torch.float32 + """dtype used when intializing the weights.""" + + timers: Optional[Callable] = None + """Timers object to call for various timing functions. 
See megatron.core.timers.Timers""" + + finalize_model_grads_func: Optional[Callable] = None + """Function that finalizes gradients on all workers. Could include ensuring that grads are + all-reduced across data parallelism, pipeline parallelism, and sequence parallelism + dimensions. + """ + + grad_scale_func: Optional[Callable] = None + """If using loss scaling, this function should take the loss and return the scaled loss. If + None, no function is called on the loss. + """ + + no_sync_func: Optional[Callable] = None + """Function that creates a context that suppresses asynchronous data-parallel communication. If + the model is an instance of core.distributed.DistributedDataParallel, the default is to use + core.distributed.DistributedDataParallel.no_sync. + """ + + grad_sync_func: Optional[Callable] = None + """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient + reduce-scatters). The function should take one argument: an iterable of parameters whose + gradients are to be synchronized. + """ + + param_sync_func: Optional[Callable] = None + """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer + parameter all-gathers). The function should take one argument: an iterable of parameters to + be synchronized. + """ + + deterministic_mode: bool = False + """If true, code that has deterministic execution will be chosen. This usually + means slower execution, but is good for debugging and testing. Defaults to False.""" + + enable_autocast: bool = False + """If true runs the forward step function inside torch.autocast context.""" + + autocast_dtype: Optional[torch.dtype] = None + """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype.""" + + num_microbatches_with_partial_activation_checkpoints: Optional[int] = None + """If int, set the number of microbatches where not all of the layers will be checkpointed and + recomputed. The rest of the microbatches within the window of maximum outstanding + microbatches will recompute all layers (either full recompute or selective recompute). If + None, the checkpoint and recompute will be left up to the forward_step function. + + """ + + ################### + # Optimizations + ################### + gradient_accumulation_fusion: bool = False + """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install + APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" + --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion. + """ + + async_tensor_model_parallel_allreduce: bool = False + """NOTE: Deprecated. This flag is ignored.""" + + use_te_rng_tracker: bool = False + """If true, uses RNG state tracker in TransformerEngine if exists. + """ + + tp_comm_overlap: bool = False + """If true, allows overlapping of Linear layer execution with tensor parallel communication + collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever + possible during the forward and the backward pass. + """ + + tp_comm_bulk_wgrad: bool = True + """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if + tp_comm_overlap is False. + """ + + tp_comm_bulk_dgrad: bool = True + """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. 
Don't care if + tp_comm_overlap is False. + """ + + tp_comm_overlap_ag: bool = True + """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather. + Don't care if tp_comm_overlap is False. + """ + + tp_comm_overlap_rs: bool = True + """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter. + Don't care if tp_comm_overlap is False. + """ + + tp_comm_overlap_rs_dgrad: bool = False + """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the + GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False. + """ + + tp_comm_split_ag: bool = True + """Deprecated from TransformerEngine v1.6.0. + If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather + splits. Don't care if tp_comm_overlap is False. + """ + + tp_comm_atomic_ag: bool = False + """Deprecated from TransformerEngine v1.6.0. + If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather + both done atomically. Don't care if tp_comm_overlap is False. + """ + + tp_comm_split_rs: bool = True + """Deprecated from TransformerEngine v1.6.0. + If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and + Reduce-Scatter splits. Don't care if tp_comm_overlap is False. + """ + + tp_comm_atomic_rs: bool = False + """Deprecated from TransformerEngine v1.6.0. + If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and + Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False. + """ + + cross_entropy_loss_fusion: bool = False + """If this is enabled, the fused cross entropy implementation would be used. + Defaults to False. + """ + + tp_comm_overlap_disable_qkv: bool = False + """ + If true, the AllGather -> Gemm overlap for QKV gets disabled + """ + + tp_comm_overlap_disable_fc1: bool = False + """ + If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled + """ + + tp_comm_bootstrap_backend: str = 'nccl' + """ + Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo' + """ + + ################### + # Pipeline Parallel + ################### + pipeline_dtype: torch.dtype = None + """dtype used in p2p communication, usually params_dtype""" + + variable_seq_lengths: bool = False + """Support for variable sequence lengths across microbatches. Setting this communicates the size + of tensors during pipeline parallelism communication, because of this extra overhead it + should only be set if the sequence length varies by microbatch within a global batch. + """ + + overlap_p2p_comm: bool = False + """When True some of the peer to peer communication for pipeline parallelism will overlap with + computation. Must be False if batch_p2p_comm is true. + """ + + batch_p2p_comm: bool = True + """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if + overlap_p2p_comm is True. + """ + + batch_p2p_sync: bool = True + """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in + older version of PyTorch. + """ + + use_ring_exchange_p2p: bool = False + """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires + custom built torch with torch.distributed.ring_exchange. + """ + + deallocate_pipeline_outputs: bool = False + """If True, output data is deallocated after the tensor is sent to the next pipeline stage. + Helps with saving memory, does nothing when pipeline parallel is not used. 
+ """ + + defer_embedding_wgrad_compute: bool = False + """If true, defers the embedding WGRAD GEMMs while pipeline flush is + taking place enabling us to hide pipeline flush latency. Defaults to False. + """ + + wgrad_deferral_limit: int = 0 + """This value tunes the number of micro-batches for which the embedding weight gradient compute + needs to be deferred to pipeline flush, this argument is invalid if + `defer_embedding_wgrad_compute` is False. + Defaults to 0, which means all micro-batches are deferred. + """ + + pipeline_model_parallel_split_rank: Optional[int] = None + """If int, rank where encoder and decoder should be split in cases where the model has both an + encoder and decoder (e.g., T5). Ignored if None. + """ + + overlap_p2p_comm_warmup_flush: bool = False + """If true, overlap communication and computation in warm up and flush phase. + Only valid when overlap_p2p_comm is True and batch_p2p_comm is False. + Defaults to False. + """ + + microbatch_group_size_per_vp_stage: Optional[int] = None + """This value specifies the number of micro-batches that are executed + at a time for a given virtual stage (both forward and backward). + Default (in __post_init__() method below) to pipeline_parallel_size + which specifies a depth-first schedule. + Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2, + num_microbatches = 4, we have + rank 0 | 0 1 0 1 2 3 2 3 + rank 1 | 0 1 0 1 2 3 2 3 + When microbatch_group_size_per_vp_stage=3, num_microbatches = 5, + we have + rank 0 | 0 1 2 0 1 2 3 4 3 4 + rank 1 | 0 1 2 0 1 2 3 4 3 4 + """ + + ################### + # CPU Offloading + ################### + cpu_offloading: bool = False + """When set to True, all the activations are offloaded to the CPU asynchronously.""" + + cpu_offloading_num_layers: int = 0 + """Tells the number of transformer layers for which activations has to be offloaded.""" + + _cpu_offloading_context: Optional[ContextManager] = ( + None + # Used for internal use only, not to be set by a user. + # TODO: Need to move to the 'right' place when possible. + ) + """For internal use only, do not set.""" + + cpu_offloading_activations: bool = True + """If True, offloads the activations to CPU.""" + + cpu_offloading_weights: bool = True + """If True, offloads the weights to CPU.""" + + ################### + # Timing + ################### + barrier_with_L1_time: bool = True + """If true, use barrier with level 1 time measurements. It is up to the user to make sure + calling barrier with their timers will not result in hangs. This can happen if for example + the user adds a level 1 timer that is not called by all ranks. + """ + + def __post_init__(self): + """Python dataclass method that is used to modify attributes after initialization. + See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more + details. 
+ """ + if self.sequence_parallel: + if self.tensor_model_parallel_size <= 1: + raise ValueError("Can not use sequence paralllelism without tensor parallelism") + + if self.expert_tensor_parallel_size is None: + self.expert_tensor_parallel_size = self.tensor_model_parallel_size + + if self.pipeline_model_parallel_size > 1: + if self.pipeline_dtype is None: + raise ValueError( + "When using pipeline parallelism, pipeline_dtype must be specified" + ) + + if self.autocast_dtype is None: + self.autocast_dtype = self.params_dtype + + if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1: + raise ValueError( + "Cannot defer embedding wgrad compute when pipeline model parallel is not used" + ) + + if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion: + raise ValueError( + "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used" + ) + + if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0: + raise ValueError( + "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!" + ) + + if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: + if self.sequence_parallel is False: + raise ValueError( + "When using expert parallelism and tensor parallelism, " + "sequence parallelism must be used" + ) + + if self.microbatch_group_size_per_vp_stage is None: + self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size + + if self.overlap_p2p_comm_warmup_flush: + if not self.overlap_p2p_comm or self.batch_p2p_comm: + raise ValueError( + "Pipeline parallel communication overlapping in warmup and flush is only " + "compatible with overlap_p2p_comm but not batch_p2p_comm." + ) diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py index 462fbfc6940fc633453719a3193c655d99ad65b7..68335591df92b6527012bf9e43557592962f66ca 100644 --- a/megatron/core/models/T5/t5_model.py +++ b/megatron/core/models/T5/t5_model.py @@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.enums import ModelType from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_block import TransformerBlock @@ -135,9 +137,13 @@ class T5Model(LanguageModule): fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', + position_embedding_type: Literal[ + 'learned_absolute', 'rope', 'relative' + ] = 'learned_absolute', rotary_percent: float = 1.0, seq_len_interpolation_factor: Optional[float] = None, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, add_encoder: bool = True, add_decoder: bool = True, ): @@ -193,6 +199,23 @@ class 
T5Model(LanguageModule): use_cpu_initialization=self.config.use_cpu_initialization, ) + # Relative Position Embeddings + if self.position_embedding_type == 'relative': + self.encoder_relative_pos_emb = RelativePositionEmbedding( + bidirectional=True, + init_method=self.config.init_method, + num_attention_heads=self.config.num_attention_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + self.decoder_relative_pos_emb = RelativePositionEmbedding( + bidirectional=False, + init_method=self.config.init_method, + num_attention_heads=self.config.num_attention_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + # Transformer encoder encoder_spec, decoder_spec = ( self.transformer_encoder_layer_spec, @@ -284,6 +307,27 @@ class T5Model(LanguageModule): ) rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + # Relative positional embeddings + encoder_attention_bias_parallel = None + if self.position_embedding_type == 'relative': + query_seq_length = RelativePositionEmbedding.get_relative_seq_len( + inference_params, self.encoder, encoder_input, self.config + ) + key_seq_length = query_seq_length + attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length) + + # Scatter attention_bias to TP ranks + # First, reshape [1, num_head, seqlen_q, seqlen_kv] to + # [1, seqlen_q, seqlen_kv, num_head] to be scatter along + # the last (num_heads dimension) + attention_bias = torch.permute(attention_bias, (0, 2, 3, 1)) + # Then, scatter to TP region + attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias) + # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv] + encoder_attention_bias_parallel = torch.permute( + attention_bias_parallel, (0, 3, 1, 2) + ) + # Run encoder. 
if self.add_encoder: encoder_hidden_states = self.encoder( @@ -291,6 +335,7 @@ class T5Model(LanguageModule): attention_mask=encoder_attn_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + attention_bias=encoder_attention_bias_parallel, ) else: encoder_hidden_states = self.encoder_hidden_state @@ -315,10 +360,29 @@ class T5Model(LanguageModule): rotary_pos_emb = None if self.position_embedding_type == 'rope': rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.encoder, encoder_input, self.config, packed_seq_params + inference_params, self.decoder, decoder_input, self.config, packed_seq_params ) rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + # Relative positional embeddings + decoder_attention_bias_parallel = None + if self.position_embedding_type == 'relative': + query_seq_length = RelativePositionEmbedding.get_relative_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + key_seq_length = query_seq_length + attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length) + + # Scatter attention_bias to TP ranks + # First, reshape [1, num_head, seqlen_q, seqlen_kv] to + # [1, seqlen_q, seqlen_kv, num_head] to be scatter along + # the last (num_heads dimension) + attention_bias = torch.permute(attention_bias, (0, 2, 3, 1)) + # Then, scatter to TP region + attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias) + # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv] + decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2)) + # Run decoder. decoder_hidden_states = self.decoder( hidden_states=decoder_input, @@ -327,12 +391,15 @@ class T5Model(LanguageModule): context_mask=encoder_decoder_attn_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + attention_bias=decoder_attention_bias_parallel, ) if self.post_process: - lm_logits = self.lm_head( - decoder_hidden_states, self.shared_embedding_or_output_weight() - ) + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight) + if lm_labels is None: # [s b h] => [b s h] return lm_logits.transpose(0, 1).contiguous() diff --git a/megatron/core/models/common/embeddings/relative_pos_embedding.py b/megatron/core/models/common/embeddings/relative_pos_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..af17bce1cc5a22d4245daabc102b101d29eb7074 --- /dev/null +++ b/megatron/core/models/common/embeddings/relative_pos_embedding.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import logging +import math +from typing import Callable + +import torch +from torch import Tensor, nn + +from megatron.core.inference_params import InferenceParams +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + +logger = logging.getLogger(__name__) + + +__all__ = ['RelativePositionEmbedding'] + + +class RelativePositionEmbedding(nn.Module): + """Relative Position Embedding for language model. 
+ + Args: + + """ + + def __init__( + self, + bidirectional: bool, + init_method: Callable, + num_attention_heads: int, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + ) -> None: + super().__init__() + + self.bidirectional = bidirectional + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, num_attention_heads + ) + init_method(self.relative_attention_bias.weight) + + def _relative_position_bucket( + self, relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from HuggingFace T5 Model: + https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/ + src/transformers/models/t5/modeling_t5.py#L397 + + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, i.e. the + distance in tokens from the attending position to the attended-to position. + If bidirectional=False, then positive relative positions are invalid. We use + smaller buckets for small absolute relative_position and larger buckets for + larger absolute relative_positions. All relative positions >=max_distance map + to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the + model has been trained on. + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, + containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger + # bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def _compute_bias(self, query_length, key_length): + """ + Adapted from HuggingFace T5 Model + https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/ + src/transformers/models/t5/modeling_t5.py#L444C9-L444C21 + + Compute binned relative position bias + + Args: + query_length (int): The length of the query sequence + (e.g., the input sequence in attention). + key_length (int): The length of the key sequence + (e.g., the sequence to compare against in attention). + + Returns: + torch.Tensor: A tensor representing the relative position bias, with shape + (1, num_heads, query_length, key_length). 
+        """
+        device = self.relative_attention_bias.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+
+        relative_position = memory_position - context_position  # shape(query_length,key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=self.bidirectional,
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(
+            relative_position_bucket
+        )  # shape(query_length,key_length,num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(
+            0
+        )  # shape(1, num_heads,query_length,key_length)
+        return values
+
+    @staticmethod
+    def get_relative_seq_len(
+        inference_params: InferenceParams,
+        transformer: TransformerBlock,
+        transformer_input: Tensor,
+        transformer_config: TransformerConfig,
+    ) -> float:
+        """Function to get the relative position embedding sequence length.
+
+        Args:
+            inference_params : Used during Inference time
+            transformer (TransformerBlock): The transformer block (decoder/encoder) used
+                by the model
+            transformer_input (Tensor): Input tensor to the transformer
+            transformer_config (TransformerConfig): Transformer config used by the model
+
+        Returns:
+            float: The relative position embedding sequence length
+        """
+        if inference_params is not None:
+            relative_seq_len = inference_params.max_sequence_length
+        else:
+            if transformer.input_tensor is not None:
+                relative_seq_len = transformer.input_tensor.size(0)
+            else:
+                relative_seq_len = transformer_input.size(0)
+
+        if transformer_config.sequence_parallel:
+            relative_seq_len *= transformer_config.tensor_model_parallel_size
+
+        return relative_seq_len
+
+    def forward(self, query_seq_length, key_seq_length):
+        """Compute the binned relative position bias.
+
+        Args:
+            query_seq_length (int): Length of the query sequence.
+            key_seq_length (int): Length of the key sequence.
+
+        Returns:
+            torch.Tensor: Relative position bias of shape
+                (1, num_heads, query_seq_length, key_seq_length).
+        """
+        return self._compute_bias(query_seq_length, key_seq_length)
diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
index c2837c6fa356aa7699cc72882bf1e752a490fa4e..407cc105f1421ca32714d0da128cc147dd49eb09 100644
--- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
@@ -1,213 +1,215 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from megatron.core.transformer.transformer_config import TransformerConfig
-    from megatron.core.transformer.transformer_block import TransformerBlock
-    from megatron.core.inference_params import InferenceParams
-    from megatron.core.packed_seq_params import PackedSeqParams
-
-import logging
-import math
-from functools import lru_cache
-
-import torch
-from torch import Tensor, nn
-
-from megatron.core import parallel_state
-from megatron.core.models.common.embeddings.rope_utils import (  # for backward compatibility; pylint: disable=unused-import
-    _apply_rotary_pos_emb_bshd,
-    _apply_rotary_pos_emb_thd,
-    _rotate_half,
-    apply_rotary_pos_emb,
-    get_pos_emb_on_this_cp_rank,
-)
-
-logger = logging.getLogger(__name__)
-
-
-__all__ = ['RotaryEmbedding']
-
-
-class RotaryEmbedding(nn.Module):
-    """Rotary Embedding for language model.
-
-    Args:
-        kv_channels (int): Projection weights dimension in multi-head attention.
Obtained - from transformer config - rotary_percent (float): Percent of rotary dimension to use for rotary position - embeddings. - rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. - Defaults to False. - seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE - for longer sequences. The value must be a float larger than 1.0. Defaults to None - rotary_base (int, optional): Base period for rotary position embeddings. Defaults to - 10000. - rope_scaling (bool, optional): Apply rope scaling as used in llama 3.1 - use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly - on the GPU. Defaults to False - """ - - def __init__( - self, - kv_channels: int, - rotary_percent: float, - rotary_interleaved: bool = False, - seq_len_interpolation_factor: float = None, - rotary_base: int = 10000, - rope_scaling: bool = False, - use_cpu_initialization: bool = False, - ) -> None: - super().__init__() - - dim = kv_channels - if rotary_percent < 1.0: - dim = int(dim * rotary_percent) - self.rotary_interleaved = rotary_interleaved - - self.seq_len_interpolation_factor = seq_len_interpolation_factor - device = 'cpu' if use_cpu_initialization else torch.cuda.current_device() - self.inv_freq = 1.0 / ( - rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) - ) - - if rope_scaling: - self.inv_freq = self._apply_scaling(self.inv_freq) - - def _apply_scaling( - self, - freqs, - factor=8, - low_freq_factor=1, - high_freq_factor=4, - original_max_position_embeddings=8192, - ): - # This implementation is adapted from: - # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343 - - factor = factor # `8` in the original implementation - low_freq_factor = low_freq_factor # `1` in the original implementation - high_freq_factor = high_freq_factor # `4` in the original implementation - old_context_len = original_max_position_embeddings # `8192` in the original implementation - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - wavelen = 2 * math.pi / freqs - # wavelen < high_freq_wavelen: do nothing - # wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs) - # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smoothed_inv_freq = ( - 1 - smooth_factor - ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - - return inv_freq_llama - - def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor: - """Generates matrix of frequencies based on positions in the sequence, - used to create positional encodings""" - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) - - if self.seq_len_interpolation_factor is not None: - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.outer(seq, self.inv_freq) # [seq len, dim] - - return freqs - - def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): - """Cosine and sine values for RoPE are precomputed for all positions up to the maximum - sequence length""" - 
freqs = self.get_freqs_non_repeated(max_seq_len, offset) - cos = torch.cos(freqs) - sin = torch.sin(freqs) - return cos, sin - - @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: - """Forward pass of RoPE embedding. - - Args: - max_seq_len (int): Maximum size of sequence - offset (int, optional): RoPE offset. Defaults to 0. - packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. - - Returns: - Tensor: Embeddings after applying RoPE. - """ - if self.inv_freq.device.type == 'cpu': - # move `inv_freq` to GPU once at the first micro-batch forward pass - self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) - - freqs = self.get_freqs_non_repeated(max_seq_len, offset) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - if not self.rotary_interleaved: - emb = torch.cat((freqs, freqs), dim=-1) - else: - emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( - freqs.shape[0], -1 - ) - # emb [seq_length, .., dim] - emb = emb[:, None, None, :] - if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0) - return emb - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - state_dict.pop(f'{prefix}inv_freq', None) - return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - - def get_rotary_seq_len( - self, - inference_params: InferenceParams, - transformer: TransformerBlock, - transformer_input: Tensor, - transformer_config: TransformerConfig, - packed_seq_params: PackedSeqParams, - ) -> float: - """Function to get the rotary sequence length. - - Args: - inference_params : Used during Inference time - transformer (TransformerBlock): The transformer block (decoder/encoder) used - by the model - transformer_input (Tensor): Input tensor to the transformer - transformer_config (TransformerConfig): Transformer config used by the model - packed_seq_params (PackedSeqParams): Packed sequence params - - Returns: - float: The rotary sequence length - """ - if packed_seq_params is not None: - # max_seqlen are the max sequence length in the packed sequence before being divived - # by the tp and cp size. - return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) - elif inference_params is not None: - rotary_seq_len = inference_params.max_sequence_length - else: - if transformer.input_tensor is not None: - rotary_seq_len = transformer.input_tensor.size(0) - else: - rotary_seq_len = transformer_input.size(0) - - if transformer_config.sequence_parallel: - rotary_seq_len *= transformer_config.tensor_model_parallel_size - - rotary_seq_len *= transformer_config.context_parallel_size - - return rotary_seq_len +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
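# [Editorial sketch, not part of the patch] Minimal usage of the RelativePositionEmbedding
# module introduced above: it is constructed per direction (bidirectional for the T5 encoder,
# causal for the decoder), and calling it returns an attention bias of shape
# [1, num_heads, query_seq_length, key_seq_length]. The init_method and sizes here are
# illustrative values, not defaults taken from the patch.
import torch
from megatron.core.models.common.embeddings.relative_pos_embedding import (
    RelativePositionEmbedding,
)

rel_pos_emb = RelativePositionEmbedding(
    bidirectional=True,                # encoder-style (non-causal) bias
    init_method=torch.nn.init.zeros_,  # any callable applied to the bias embedding weight
    num_attention_heads=12,
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
)
bias = rel_pos_emb(query_seq_length=16, key_seq_length=16)
assert bias.shape == (1, 12, 16, 16)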
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.transformer.transformer_block import TransformerBlock + from megatron.core.inference_params import InferenceParams + from megatron.core.packed_seq_params import PackedSeqParams + +import logging +import math +from functools import lru_cache + +import torch +from torch import Tensor, nn + +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import ( # for backward compatibility; pylint: disable=unused-import + _apply_rotary_pos_emb_bshd, + _apply_rotary_pos_emb_thd, + _rotate_half, + apply_rotary_pos_emb, + get_pos_emb_on_this_cp_rank, +) + +logger = logging.getLogger(__name__) + + +__all__ = ['RotaryEmbedding'] + + +class RotaryEmbedding(nn.Module): + """Rotary Embedding for language model. + + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position + embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE + for longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. + rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x. + rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8. + use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly + on the GPU. 
Defaults to False + """ + + def __init__( + self, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float = None, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + use_cpu_initialization: bool = False, + ) -> None: + super().__init__() + + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + self.rotary_interleaved = rotary_interleaved + + self.seq_len_interpolation_factor = seq_len_interpolation_factor + device = 'cpu' if use_cpu_initialization else torch.cuda.current_device() + self.inv_freq = 1.0 / ( + rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + if rope_scaling: + self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor) + + def _apply_scaling( + self, + freqs, + factor=8, + low_freq_factor=1, + high_freq_factor=4, + original_max_position_embeddings=8192, + ): + # This implementation is adapted from: + # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343 + + factor = factor # `8` in the original implementation + low_freq_factor = low_freq_factor # `1` in the original implementation + high_freq_factor = high_freq_factor # `4` in the original implementation + old_context_len = original_max_position_embeddings # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / freqs + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama + + def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Generates matrix of frequencies based on positions in the sequence, + used to create positional encodings""" + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) + + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.outer(seq, self.inv_freq) # [seq len, dim] + + return freqs + + def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): + """Cosine and sine values for RoPE are precomputed for all positions up to the maximum + sequence length""" + freqs = self.get_freqs_non_repeated(max_seq_len, offset) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + + @lru_cache(maxsize=32) + def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + """Forward pass of RoPE embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. + + Returns: + Tensor: Embeddings after applying RoPE. 
+ """ + if self.inv_freq.device.type == 'cpu': + # move `inv_freq` to GPU once at the first micro-batch forward pass + self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) + + freqs = self.get_freqs_non_repeated(max_seq_len, offset) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + if not self.rotary_interleaved: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( + freqs.shape[0], -1 + ) + # emb [seq_length, .., dim] + emb = emb[:, None, None, :] + if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0) + return emb + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + state_dict.pop(f'{prefix}inv_freq', None) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def get_rotary_seq_len( + self, + inference_params: InferenceParams, + transformer: TransformerBlock, + transformer_input: Tensor, + transformer_config: TransformerConfig, + packed_seq_params: PackedSeqParams, + ) -> float: + """Function to get the rotary sequence length. + + Args: + inference_params : Used during Inference time + transformer (TransformerBlock): The transformer block (decoder/encoder) used + by the model + transformer_input (Tensor): Input tensor to the transformer + transformer_config (TransformerConfig): Transformer config used by the model + packed_seq_params (PackedSeqParams): Packed sequence params + + Returns: + float: The rotary sequence length + """ + if packed_seq_params is not None: + # max_seqlen are the max sequence length in the packed sequence before being divived + # by the tp and cp size. + return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) + elif inference_params is not None: + rotary_seq_len = inference_params.max_sequence_length + else: + if transformer is not None and transformer.input_tensor is not None: + rotary_seq_len = transformer.input_tensor.size(0) + else: + rotary_seq_len = transformer_input.size(0) + + if transformer_config.sequence_parallel: + rotary_seq_len *= transformer_config.tensor_model_parallel_size + + rotary_seq_len *= transformer_config.context_parallel_size + + return rotary_seq_len diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index d0e48c190cacc27a944e9a4bc3a748e3c4570eb7..225626f3f120671de225f1f6df58153bc38088e5 100644 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,350 +1,383 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
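# [Editorial sketch, not part of the patch] Shows the new rope_scaling / rope_scaling_factor
# arguments added to RotaryEmbedding in the hunk above. With rope_scaling=True the inverse
# frequencies are rescaled Llama-3.x style: low-frequency components are divided by roughly
# rope_scaling_factor while high-frequency components are left untouched.
# use_cpu_initialization=True keeps the sketch off the GPU; the numbers are illustrative.
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

plain = RotaryEmbedding(kv_channels=128, rotary_percent=1.0, use_cpu_initialization=True)
scaled = RotaryEmbedding(
    kv_channels=128,
    rotary_percent=1.0,
    rope_scaling=True,
    rope_scaling_factor=8.0,
    use_cpu_initialization=True,
)
print(plain.inv_freq[-1] / scaled.inv_freq[-1])  # lowest frequency scaled by ~8.0
print(plain.inv_freq[0] / scaled.inv_freq[0])    # highest frequency unchanged (~1.0)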
- -import warnings -from typing import Optional - -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.multi_latent_attention import ( - MLASelfAttention, - MLASelfAttentionSubmodules, -) -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import ( - TransformerBlockSubmodules, - get_num_layers_to_build, -) -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules -from megatron.core.utils import is_te_min_version - -try: - from megatron.core.extensions.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TENorm, - TERowParallelLinear, - ) - - HAVE_TE = True -except ImportError: - HAVE_TE = False - -try: - import apex # pylint: disable=unused-import - - from megatron.core.fusions.fused_layer_norm import FusedLayerNorm - - HAVE_APEX = True - LNImpl = FusedLayerNorm -except ImportError: - from megatron.core.transformer.torch_norm import WrappedTorchNorm - - warnings.warn('Apex is not installed. Falling back to Torch Norm') - LNImpl = WrappedTorchNorm - - -def get_gpt_layer_with_transformer_engine_spec( - num_experts: Optional[int] = None, - moe_grouped_gemm: Optional[bool] = False, - qk_layernorm: Optional[bool] = False, - multi_latent_attention: Optional[bool] = False, - fp8: Optional[str] = None, # pylint: disable=unused-arguments - moe_use_legacy_grouped_gemm: Optional[bool] = False, -) -> ModuleSpec: - """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). - - - Args: - num_experts (int, optional): Number of experts. Defaults to None. - moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. - qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. - fp8 (str, optional): Deprecated. For temporary Nemo compatibility. - moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. - Defaults to False. - - Returns: - ModuleSpec: Module specification with TE modules - """ - if fp8 is not None: - warnings.warn( - 'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated' - ' and will be removed soon. Please update your code accordingly.' 
- ) - - mlp = _get_mlp_module_spec( - use_te=True, - num_experts=num_experts, - moe_grouped_gemm=moe_grouped_gemm, - moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, - ) - - if multi_latent_attention: - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=TENorm, - self_attention=ModuleSpec( - module=MLASelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=MLASelfAttentionSubmodules( - linear_q_proj=TEColumnParallelLinear, - linear_q_down_proj=TEColumnParallelLinear, - linear_q_up_proj=TEColumnParallelLinear, - linear_kv_down_proj=TEColumnParallelLinear, - linear_kv_up_proj=TEColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - q_layernorm=TENorm if qk_layernorm else IdentityOp, - kv_layernorm=TENorm if qk_layernorm else IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=TENorm if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) - else: - - # TENorm significantly harms convergence when used - # for QKLayerNorm if TE Version < 1.9; - # we instead use the Apex implementation. - qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - q_layernorm=qk_norm if qk_layernorm else IdentityOp, - k_layernorm=qk_norm if qk_layernorm else IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=TENorm if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) - - -def get_gpt_layer_local_spec( - num_experts: Optional[int] = None, - moe_grouped_gemm: Optional[bool] = False, - qk_layernorm: Optional[bool] = False, - multi_latent_attention: Optional[bool] = False, - fp8: Optional[str] = None, # pylint: disable=unused-arguments - moe_use_legacy_grouped_gemm: Optional[bool] = False, -) -> ModuleSpec: - """Use this spec for an implementation using only modules in Megatron-Core. - - - Args: - num_experts (int, optional): Number of experts. Defaults to None. - moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. - qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. - fp8 (str, optional): Deprecated. For temporary Nemo compatibility. - moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. - Defaults to False. - - Returns: - ModuleSpec: Module specification with Megatron-Core modules - """ - if fp8 is not None: - warnings.warn( - 'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated' - ' and will be removed soon. Please update your code accordingly.' 
- ) - - mlp = _get_mlp_module_spec( - use_te=False, - num_experts=num_experts, - moe_grouped_gemm=moe_grouped_gemm, - moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, - ) - - if multi_latent_attention: - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=LNImpl, - self_attention=ModuleSpec( - module=MLASelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=MLASelfAttentionSubmodules( - linear_q_proj=ColumnParallelLinear, - linear_q_down_proj=ColumnParallelLinear, - linear_q_up_proj=ColumnParallelLinear, - linear_kv_down_proj=ColumnParallelLinear, - linear_kv_up_proj=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - q_layernorm=LNImpl if qk_layernorm else IdentityOp, - kv_layernorm=LNImpl if qk_layernorm else IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=LNImpl, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) - else: - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=LNImpl, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - q_layernorm=LNImpl if qk_layernorm else IdentityOp, - k_layernorm=LNImpl if qk_layernorm else IdentityOp, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=LNImpl, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - sharded_state_dict_keys_map={ - 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', - 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', - }, - ), - ) - - -def _get_mlp_module_spec( - use_te: Optional[bool] = True, - num_experts: Optional[int] = None, - moe_grouped_gemm: Optional[bool] = False, - fp8: Optional[str] = None, # pylint: disable=unused-arguments - moe_use_legacy_grouped_gemm: Optional[bool] = False, -) -> ModuleSpec: - """Helper function to get module spec for MLP/MoE""" - if fp8 is not None: - warnings.warn( - 'The fp8 argument in "_get_mlp_module_spec" has been deprecated' - ' and will be removed soon. Please update your code accordingly.' - ) - - if num_experts is None: - # Dense MLP w/ or w/o TE modules. - return ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ), - ) - else: - # Mixture of experts with modules in megatron core. - return get_moe_module_spec( - use_te=use_te, - num_experts=num_experts, - moe_grouped_gemm=moe_grouped_gemm, - moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, - ) - - -def get_gpt_decoder_block_spec( - config: TransformerConfig, use_transformer_engine: bool -) -> TransformerBlockSubmodules: - """GPT block spec.""" - if use_transformer_engine: - layer_norm_impl = TENorm - else: - layer_norm_impl = LNImpl - - # Layer specs. 
- dense_layer_spec = ( - get_gpt_layer_with_transformer_engine_spec( - num_experts=None, - moe_grouped_gemm=False, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - ) - if use_transformer_engine - else get_gpt_layer_local_spec( - num_experts=None, - moe_grouped_gemm=False, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - ) - ) - moe_layer_spec = ( - get_gpt_layer_with_transformer_engine_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - ) - if use_transformer_engine - else get_gpt_layer_local_spec( - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - qk_layernorm=config.qk_layernorm, - multi_latent_attention=config.multi_latent_attention, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - ) - ) - - # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. - # 0 stands for dense layers, 1 stands for expert layers. - # For integer N: Creates a pattern with one expert layer every N layers. - # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense). - if isinstance(config.moe_layer_freq, int): - moe_layer_pattern = [ - 1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers) - ] - elif isinstance(config.moe_layer_freq, list): - moe_layer_pattern = config.moe_layer_freq - assert len(moe_layer_pattern) == config.num_layers, ( - f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " - f"expected {config.num_layers}, " - f"current moe layer pattern: {config.moe_layer_freq}" - ) - else: - raise ValueError( - f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}" - ) - - # Create the layer specs for the model. - layer_specs = [] - for layer_number in range(config.num_layers): - if moe_layer_pattern[layer_number] == 1: - layer_specs.append(moe_layer_spec) - elif moe_layer_pattern[layer_number] == 0: - layer_specs.append(dense_layer_spec) - else: - raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}") - - # Slice the layer specs to only include the layers that are built in this pipeline stage. - # Note: MCore layer_number starts at 1 - offset = TransformerLayer._get_layer_offset(config) - num_layers_to_build = get_num_layers_to_build(config) - layer_specs = layer_specs[offset : offset + num_layers_to_build] - - # Block spec. - block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl) - - return block_spec +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
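# [Editorial sketch, not part of the patch] The moe_layer_freq handling in
# get_gpt_decoder_block_spec above builds a 0/1 pattern over the layers: an int N places an
# expert (MoE) layer every N layers, while an explicit list is used verbatim. A minimal
# pure-Python reconstruction of that parsing, with an illustrative helper name:
def _moe_layer_pattern(moe_layer_freq, num_layers):
    if isinstance(moe_layer_freq, int):
        return [1 if (i % moe_layer_freq == 0) else 0 for i in range(num_layers)]
    if isinstance(moe_layer_freq, list):
        assert len(moe_layer_freq) == num_layers
        return moe_layer_freq
    raise ValueError(f"Invalid moe_layer_freq: {moe_layer_freq}")

# _moe_layer_pattern(2, 6)         -> [1, 0, 1, 0, 1, 0]   (MoE every other layer)
# _moe_layer_pattern([1, 0, 1], 3) -> [1, 0, 1]            (explicit per-layer pattern)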
+ +import warnings +from typing import Optional + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, + get_transformer_layer_offset, +) +from megatron.core.utils import is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: disable=unused-import + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn('Apex is not installed. Falling back to Torch Norm') + LNImpl = WrappedTorchNorm + + +def get_gpt_layer_with_transformer_engine_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + + Returns: + ModuleSpec: Module specification with TE modules + """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' 
+ ) + + mlp = get_mlp_module_spec( + use_te=True, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + if multi_latent_attention: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TEColumnParallelLinear, + linear_q_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + linear_kv_down_proj=TEColumnParallelLinear, + linear_kv_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + else: + + # TENorm significantly harms convergence when used + # for QKLayerNorm if TE Version < 1.9; + # we instead use the Apex implementation. + qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=qk_norm if qk_layernorm else IdentityOp, + k_layernorm=qk_norm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_gpt_layer_local_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec for an implementation using only modules in Megatron-Core. + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + + Returns: + ModuleSpec: Module specification with Megatron-Core modules + """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' 
+ ) + + mlp = get_mlp_module_spec( + use_te=False, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + if multi_latent_attention: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=ColumnParallelLinear, + linear_q_down_proj=ColumnParallelLinear, + linear_q_up_proj=ColumnParallelLinear, + linear_kv_down_proj=ColumnParallelLinear, + linear_kv_up_proj=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=LNImpl if qk_layernorm else IdentityOp, + kv_layernorm=LNImpl if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + else: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=LNImpl if qk_layernorm else IdentityOp, + k_layernorm=LNImpl if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), + ) + + +def _get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +): + warnings.warn( + """This private function is on a deprecation track. Please switch to `get_mlp_module_spec` + since it will be removed in a future release.""" + ) + + return get_mlp_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + fp8=fp8, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MLP/MoE""" + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "_get_mlp_module_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' + ) + + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. 
+ return get_moe_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def get_gpt_decoder_block_spec( + config: TransformerConfig, use_transformer_engine: bool +) -> TransformerBlockSubmodules: + """GPT block spec.""" + if use_transformer_engine: + layer_norm_impl = TENorm + else: + layer_norm_impl = LNImpl + + # Layer specs. + dense_layer_spec = ( + get_gpt_layer_with_transformer_engine_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + if use_transformer_engine + else get_gpt_layer_local_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + ) + moe_layer_spec = ( + get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + if use_transformer_engine + else get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + ) + + # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. + # 0 stands for dense layers, 1 stands for expert layers. + # For integer N: Creates a pattern with one expert layer every N layers. + # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense). + if isinstance(config.moe_layer_freq, int): + moe_layer_pattern = [ + 1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers) + ] + elif isinstance(config.moe_layer_freq, list): + moe_layer_pattern = config.moe_layer_freq + assert len(moe_layer_pattern) == config.num_layers, ( + f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " + f"expected {config.num_layers}, " + f"current moe layer pattern: {config.moe_layer_freq}" + ) + else: + raise ValueError( + f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}" + ) + + # Create the layer specs for the model. + layer_specs = [] + for layer_number in range(config.num_layers): + if moe_layer_pattern[layer_number] == 1: + layer_specs.append(moe_layer_spec) + elif moe_layer_pattern[layer_number] == 0: + layer_specs.append(dense_layer_spec) + else: + raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}") + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + # Note: MCore layer_number starts at 1 + offset = get_transformer_layer_offset(config) + num_layers_to_build = get_num_layers_to_build(config) + layer_specs = layer_specs[offset : offset + num_layers_to_build] + + # Block spec. 
+ block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl) + + return block_spec diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index be8cdce1119df0eee2e639ee881324064515f9c8..8f50cbe85ab13442c411533400522f4e96fc9725 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -1,309 +1,331 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from collections import OrderedDict -from typing import Dict, Literal, Optional - -from torch import Tensor - -from megatron.core import InferenceParams, tensor_parallel -from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig - - -class GPTModel(LanguageModule): - """GPT Transformer language model. - - Args: - config (TransformerConfig): - Transformer config - transformer_layer_spec (ModuleSpec): - Specifies module to use for transformer layers - vocab_size (int): - Vocabulary size - max_sequence_length (int): - maximum size of sequence. This is used for positional embedding - pre_process (bool, optional): - Include embedding layer (used with pipeline parallelism). Defaults to True. - post_process (bool, optional): - Include an output layer (used with pipeline parallelism). Defaults to True. - fp16_lm_cross_entropy (bool, optional): - Defaults to False. - parallel_output (bool, optional): - Do not gather the outputs, keep them split across tensor - parallel ranks. Defaults to True. - share_embeddings_and_output_weights (bool, optional): - When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (Literal[learned_absolute,rope], optional): - Position embedding type.. Defaults to 'learned_absolute'. - rotary_percent (float, optional): - Percent of rotary dimension to use for rotary position embeddings. - Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. - rotary_base (int, optional): - Base period for rotary position embeddings. Ignored unless - position_embedding_type is 'rope'. - Defaults to 10000. - scatter_embedding_sequence_parallel (bool, optional): - Whether embeddings should be scattered across sequence parallel - region or not. Defaults to True. - seq_len_interpolation_factor (Optional[float], optional): - scale of linearly interpolating RoPE for longer sequences. - The value must be a float larger than 1.0. Defaults to None. 
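For orientation, here is a minimal usage sketch of the spec builders above. It assumes Transformer Engine is installed and that the helpers are imported from `megatron.core.models.gpt.gpt_layer_specs`; the argument values are illustrative, not defaults taken from this patch.

```python
# Illustrative sketch only; argument values are examples, not recommended settings.
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)

# Dense GPT layer built from lower-level Transformer Engine modules
# (the path required for fp8 training).
te_dense_layer_spec = get_gpt_layer_with_transformer_engine_spec(qk_layernorm=True)

# MoE layer with 8 experts using only Megatron-Core modules; without TE,
# grouped GEMM falls back to the legacy GroupedMLP (with a deprecation warning).
local_moe_layer_spec = get_gpt_layer_local_spec(num_experts=8, moe_grouped_gemm=True)

# get_gpt_decoder_block_spec expands config.moe_layer_freq into a per-layer 0/1
# pattern: e.g. moe_layer_freq=2 with num_layers=4 gives [1, 0, 1, 0], i.e. MoE
# layers at positions 1 and 3 (1-indexed) and dense layers elsewhere.
```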
- """ - - def __init__( - self, - config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', - rotary_percent: float = 1.0, - rotary_base: int = 10000, - rope_scaling: bool = False, - scatter_embedding_sequence_parallel: bool = True, - seq_len_interpolation_factor: Optional[float] = None, - ) -> None: - super().__init__(config=config) - - if has_config_logger_enabled(config): - log_config_to_disk(config, locals(), prefix=type(self).__name__) - - self.transformer_layer_spec: ModuleSpec = transformer_layer_spec - self.vocab_size = vocab_size - self.max_sequence_length = max_sequence_length - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.position_embedding_type = position_embedding_type - - # megatron core pipelining currently depends on model type - # TODO: remove this dependency ? - self.model_type = ModelType.encoder_or_decoder - - # These 4 attributes are needed for TensorRT-LLM export. - self.max_position_embeddings = max_sequence_length - self.rotary_percent = rotary_percent - self.rotary_base = rotary_base - self.rotary_scaling = rope_scaling - - if self.pre_process: - self.embedding = LanguageModelEmbedding( - config=self.config, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - position_embedding_type=position_embedding_type, - scatter_to_sequence_parallel=scatter_embedding_sequence_parallel, - ) - - if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: - self.rotary_pos_emb = RotaryEmbedding( - kv_channels=self.config.kv_channels, - rotary_percent=rotary_percent, - rotary_interleaved=self.config.rotary_interleaved, - seq_len_interpolation_factor=seq_len_interpolation_factor, - rotary_base=rotary_base, - rope_scaling=rope_scaling, - use_cpu_initialization=self.config.use_cpu_initialization, - ) - - # Transformer. - self.decoder = TransformerBlock( - config=self.config, - spec=transformer_layer_spec, - pre_process=self.pre_process, - post_process=self.post_process, - ) - - # Output - if post_process: - if self.config.defer_embedding_wgrad_compute: - # The embedding activation buffer preserves a reference to the input activations - # of the final embedding projection layer GEMM. It will hold the activations for - # all the micro-batches of a global batch for the last pipeline stage. Once we are - # done with all the back props for all the microbatches for the last pipeline stage, - # it will be in the pipeline flush stage. During this pipeline flush we use the - # input activations stored in embedding activation buffer and gradient outputs - # stored in gradient buffer to calculate the weight gradients for the embedding - # final linear layer. 
- self.embedding_activation_buffer = [] - self.grad_output_buffer = [] - else: - self.embedding_activation_buffer = None - self.grad_output_buffer = None - - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.pre_process - and self.share_embeddings_and_output_weights, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, - ) - - if self.pre_process or self.post_process: - self.setup_embeddings_and_output_layer() - - if has_config_logger_enabled(self.config): - log_config_to_disk( - self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt' - ) - - def set_input_tensor(self, input_tensor: Tensor) -> None: - """Sets input tensor to the model. - - See megatron.model.transformer.set_input_tensor() - - Args: - input_tensor (Tensor): Sets the input tensor for the model. - """ - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' - self.decoder.set_input_tensor(input_tensor[0]) - - def forward( - self, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - - Args: - runtime_gather_output (bool): Gather output at runtime. Default None means - `parallel_output` arg in the constructor will be used. - """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - - # Decoder embedding. - if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - rotary_pos_cos = None - rotary_pos_sin = None - if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: - if not self.training and self.config.flash_decode: - # Flash decoding uses precomputed cos and sin for RoPE - rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin( - inference_params.max_sequence_length - ) - else: - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config, packed_seq_params - ) - rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None - and packed_seq_params.qkv_format == 'thd', - ) - - # Run decoder. 
- hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - if not self.post_process: - return hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer( - hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output - ) - - if has_config_logger_enabled(self.config): - payload = OrderedDict( - { - 'input_ids': input_ids, - 'position_ids': position_ids, - 'attention_mask': attention_mask, - 'decoder_input': decoder_input, - 'logits': logits, - } - ) - log_config_to_disk(self.config, payload, prefix='input_and_logits') - - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss - - def sharded_state_dict( - self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None - ) -> ShardedStateDict: - """Sharded state dict implementation for GPTModel backward-compatibility - (removing extra state). - - Args: - prefix (str): Module name prefix. - sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. - metadata (Optional[Dict]): metadata controlling sharded state dict creation. - - Returns: - ShardedStateDict: sharded state dict for the GPTModel - """ - sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - output_layer_extra_state_key = f'{prefix}output_layer._extra_state' - - # Old GPT checkpoints only stored the output layer weight key. So we remove the - # _extra_state key but check that it doesn't contain any data anyway - output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None) - assert not ( - output_extra_state and output_extra_state.data - ), f'Expected output layer extra state to be empty, got: {output_extra_state}' - - return sharded_state_dict +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from collections import OrderedDict +from typing import Dict, Literal, Optional + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + + +class GPTModel(LanguageModule): + """GPT Transformer language model. + + Args: + config (TransformerConfig): + Transformer config + transformer_layer_spec (ModuleSpec): + Specifies module to use for transformer layers + vocab_size (int): + Vocabulary size + max_sequence_length (int): + maximum size of sequence. 
This is used for positional embedding + pre_process (bool, optional): + Include embedding layer (used with pipeline parallelism). Defaults to True. + post_process (bool, optional): + Include an output layer (used with pipeline parallelism). Defaults to True. + fp16_lm_cross_entropy (bool, optional): + Defaults to False. + parallel_output (bool, optional): + Do not gather the outputs, keep them split across tensor + parallel ranks. Defaults to True. + share_embeddings_and_output_weights (bool, optional): + When True, input embeddings and output logit weights are shared. Defaults to False. + position_embedding_type (Literal[learned_absolute,rope], optional): + Position embedding type.. Defaults to 'learned_absolute'. + rotary_percent (float, optional): + Percent of rotary dimension to use for rotary position embeddings. + Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. + rotary_base (int, optional): + Base period for rotary position embeddings. Ignored unless + position_embedding_type is 'rope'. + Defaults to 10000. + rope_scaling (bool, optional): Toggle RoPE scaling. + rope_scaling_factor (float): RoPE scaling factor. Default 8. + scatter_embedding_sequence_parallel (bool, optional): + Whether embeddings should be scattered across sequence parallel + region or not. Defaults to True. + seq_len_interpolation_factor (Optional[float], optional): + scale of linearly interpolating RoPE for longer sequences. + The value must be a float larger than 1.0. Defaults to None. + """ + + def __init__( + self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + scatter_embedding_sequence_parallel: bool = True, + seq_len_interpolation_factor: Optional[float] = None, + ) -> None: + super().__init__(config=config) + + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + + self.transformer_layer_spec: ModuleSpec = transformer_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + # These 4 attributes are needed for TensorRT-LLM export. 
+ self.max_position_embeddings = max_sequence_length + self.rotary_percent = rotary_percent + self.rotary_base = rotary_base + self.rotary_scaling = rope_scaling + + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + scatter_to_sequence_parallel=scatter_embedding_sequence_parallel, + ) + + if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + rope_scaling=rope_scaling, + rope_scaling_factor=rope_scaling_factor, + use_cpu_initialization=self.config.use_cpu_initialization, + ) + + # Cache for RoPE tensors which do not change between iterations. + self.rotary_pos_emb_cache = {} + + # Transformer. + self.decoder = TransformerBlock( + config=self.config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + # Output + if post_process: + if self.config.defer_embedding_wgrad_compute: + # The embedding activation buffer preserves a reference to the input activations + # of the final embedding projection layer GEMM. It will hold the activations for + # all the micro-batches of a global batch for the last pipeline stage. Once we are + # done with all the back props for all the microbatches for the last pipeline stage, + # it will be in the pipeline flush stage. During this pipeline flush we use the + # input activations stored in embedding activation buffer and gradient outputs + # stored in gradient buffer to calculate the weight gradients for the embedding + # final linear layer. + self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + if has_config_logger_enabled(self.config): + log_config_to_disk( + self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt' + ) + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. 
+ """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. + """ + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + rotary_pos_cos = None + rotary_pos_sin = None + if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: + if not self.training and self.config.flash_decode and inference_params: + # Flash decoding uses precomputed cos and sin for RoPE + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( + inference_params.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), + ) + else: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == 'thd', + ) + if ( + (self.config.enable_cuda_graph or self.config.flash_decode) + and rotary_pos_cos is not None + and inference_params + ): + sequence_len_offset = torch.tensor( + [inference_params.sequence_len_offset] * inference_params.current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Run decoder. 
+ hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + + if has_config_logger_enabled(self.config): + payload = OrderedDict( + { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + } + ) + log_config_to_disk(self.config, payload, prefix='input_and_logits') + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility + (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + output_layer_extra_state_key = f'{prefix}output_layer._extra_state' + + # Old GPT checkpoints only stored the output layer weight key. So we remove the + # _extra_state key but check that it doesn't contain any data anyway + output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None) + assert not ( + output_extra_state and output_extra_state.data + ), f'Expected output layer extra state to be empty, got: {output_extra_state}' + + return sharded_state_dict diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 513eeddc7e3a12824d97fd12b3b66a644c3ecee7..1d53a3bc09758a889f2d8f032b4d756ed00734cc 100644 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -1,81 +1,81 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
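As a rough sketch of how the updated `GPTModel` constructor ties together with a layer spec: the hyperparameters below are placeholders, Transformer Engine is assumed to be available, and `torch.distributed` plus Megatron-Core model-parallel state are assumed to be initialized already.

```python
# Sketch only: placeholder hyperparameters; assumes torch.distributed and
# parallel_state.initialize_model_parallel(...) have already been set up.
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2, hidden_size=256, num_attention_heads=8, use_cpu_initialization=True
)

model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(),
    vocab_size=32000,
    max_sequence_length=4096,
    position_embedding_type='rope',
    rotary_base=10000,
    rope_scaling=True,         # toggles RoPE scaling in RotaryEmbedding
    rope_scaling_factor=8.0,   # new constructor argument, forwarded to RotaryEmbedding
)
```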
- -import warnings -from typing import Optional - -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.mlp import MLPSubmodules -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP -from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.utils import get_te_version, is_te_min_version - -try: - from megatron.core.extensions.transformer_engine import ( - TEColumnParallelGroupedLinear, - TEColumnParallelLinear, - TERowParallelGroupedLinear, - TERowParallelLinear, - ) - - HAVE_TE = True -except ImportError: - HAVE_TE = False - - -def get_moe_module_spec( - use_te: Optional[bool] = True, - num_experts: Optional[int] = None, - moe_grouped_gemm: Optional[bool] = False, - moe_use_legacy_grouped_gemm: Optional[bool] = False, -) -> ModuleSpec: - """Helper function to get module spec for MoE""" - assert num_experts is not None - - mlp = MLPSubmodules( - linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ) - - # experts spec - if moe_grouped_gemm: - ## use GroupedMLP - if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm: - ## use TEGroupedLinear - expert_module = TEGroupedMLP - expert_submodule = MLPSubmodules( - linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear - ) - else: - ## use legacy GroupedMLP - expert_module = GroupedMLP - expert_submodule = None - warnings.warn( - 'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. ' - 'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.' - ) - else: - ## use SequentialMLP - expert_module = SequentialMLP - if use_te and not is_te_min_version("1.7.0.dev0"): - warnings.warn( - "Only transformer-engine>=1.7.0 supports MoE experts, " - f"but your version is {get_te_version()}. Use local linear implementation instead." - ) - expert_submodule = MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear - ) - else: - expert_submodule = mlp - - experts = ModuleSpec(module=expert_module, submodules=expert_submodule) - - # shared experts spec - shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp) - - # MoE module spec - moe_module_spec = ModuleSpec( - module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) - ) - return moe_module_spec +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +import warnings +from typing import Optional + +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.utils import get_te_version, is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TEColumnParallelLinear, + TERowParallelGroupedLinear, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +def get_moe_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MoE""" + assert num_experts is not None + + mlp = MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ) + + # experts spec + if moe_grouped_gemm: + ## use GroupedMLP + if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm: + ## use TEGroupedLinear + expert_module = TEGroupedMLP + expert_submodule = MLPSubmodules( + linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear + ) + else: + ## use legacy GroupedMLP + expert_module = GroupedMLP + expert_submodule = None + warnings.warn( + 'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. ' + 'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.' + ) + else: + ## use SequentialMLP + expert_module = SequentialMLP + if use_te and not is_te_min_version("1.7.0.dev0"): + warnings.warn( + "Only transformer-engine>=1.7.0 supports MoE experts, " + f"but your version is {get_te_version()}. Use local linear implementation instead." + ) + expert_submodule = MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ) + else: + expert_submodule = mlp + + experts = ModuleSpec(module=expert_module, submodules=expert_submodule) + + # shared experts spec + shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp) + + # MoE module spec + moe_module_spec = ModuleSpec( + module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) + ) + return moe_module_spec diff --git a/megatron/core/models/huggingface/__init__.py b/megatron/core/models/huggingface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e0b011726005b5b9ffc09e4041ba2f5203bfb9 --- /dev/null +++ b/megatron/core/models/huggingface/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from .module import HuggingFaceModule, build_hf_model diff --git a/megatron/core/models/huggingface/clip_model.py b/megatron/core/models/huggingface/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a1517493d2a9a3d01e480cf3f98872548341fdc8 --- /dev/null +++ b/megatron/core/models/huggingface/clip_model.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
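A short sketch of how `get_moe_module_spec` selects the expert implementation. It assumes Transformer Engine >= 1.7 is installed; only specs are built here, no parameters are allocated.

```python
# Spec construction only; no weights are created at this point.
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

te_grouped = get_moe_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=True)
legacy_grouped = get_moe_module_spec(
    use_te=True, num_experts=8, moe_grouped_gemm=True, moe_use_legacy_grouped_gemm=True
)
sequential = get_moe_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=False)

print(te_grouped.submodules.experts.module)      # TEGroupedMLP
print(legacy_grouped.submodules.experts.module)  # GroupedMLP (deprecated path)
print(sequential.submodules.experts.module)      # SequentialMLP
```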
+ +from transformers import AutoModel + +from megatron.core.models.huggingface import HuggingFaceModule + + +class ClipHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for CLIP HuggingFace models + """ + + def __init__(self, config): + super().__init__(config) + self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + x = self.model(*args, **kwargs) + x = x['last_hidden_state'] + + return x diff --git a/megatron/core/models/huggingface/module.py b/megatron/core/models/huggingface/module.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1f7e881c28a06503cf4ee95951fb169d16d7be --- /dev/null +++ b/megatron/core/models/huggingface/module.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from transformers import AutoConfig, AutoModel + +from megatron.core.transformer.module import MegatronModule + + +class HuggingFaceModule(MegatronModule): + """ + Basic module for huggingface + """ + + def __init__(self, config): + super().__init__(config=config) + + def set_input_tensor(self, input_tensor): + """Dummy function for set_input_tensor""" + self.input_tensor = input_tensor + + +class AutoHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for HuggingFace AutoModel + """ + + def __init__(self, config): + super().__init__(config) + self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + return self.model(*args, **kwargs) + + +def build_hf_model(config): + """Builds huggingface wrapper model given config""" + hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path) + + if "qwen" in hf_config.model_type: + from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel + + model = QwenHuggingFaceModel(config) + elif "vit" in hf_config.model_type: + from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel + + model = ClipHuggingFaceModel(config) + else: + raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported") + + return model diff --git a/megatron/core/models/huggingface/qwen_model.py b/megatron/core/models/huggingface/qwen_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3a02057a507b43c8829fbf0e95c623ebb0ac069c --- /dev/null +++ b/megatron/core/models/huggingface/qwen_model.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
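The `build_hf_model` factory above dispatches on the HuggingFace `model_type`. Below is a standalone restatement of that dispatch rule; `pick_wrapper` is a hypothetical helper and the checkpoint name is only an example.

```python
# Hypothetical helper mirroring the dispatch rule in build_hf_model (no weights loaded).
from transformers import AutoConfig


def pick_wrapper(model_name_or_path: str) -> str:
    hf_config = AutoConfig.from_pretrained(model_name_or_path)
    if "qwen" in hf_config.model_type:
        return "QwenHuggingFaceModel"
    if "vit" in hf_config.model_type:
        return "ClipHuggingFaceModel"
    raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported")


# Example (checkpoint name is illustrative):
# pick_wrapper("Qwen/Qwen2-1.5B") -> "QwenHuggingFaceModel"
```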
+ +from transformers.models.qwen2 import Qwen2ForCausalLM + +from megatron.core.models.huggingface import HuggingFaceModule + + +class QwenHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for Qwen LM HuggingFace models + """ + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2ForCausalLM.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + combined_embeddings = kwargs['decoder_input'].permute(1, 0, 2) + x = self.model( + position_ids=None, # TODO: I guess we're just assuming no custom pos ids + attention_mask=kwargs['attention_mask'], + inputs_embeds=combined_embeddings, + labels=kwargs['labels'], + ) + + if kwargs['labels'] is not None: + x = x["loss"] + else: + x = x["logits"] + + return x + + def embedding(self, input_ids, position_ids=None): + """Function to run process tokens with input embeddings""" + return self.model.get_input_embeddings()(input_ids).transpose(1, 0).contiguous() diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index e5fa9efa72c0acd9e301c791e98b1d5a5060d62e..97ddd20004e85dc8ba72467261f394e2b3165e87 100644 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -1,67 +1,67 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.extensions.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules -from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules -from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - -mamba_stack_spec = ModuleSpec( - module=MambaStack, - submodules=MambaStackSubmodules( - mamba_layer=ModuleSpec( - module=MambaLayer, - submodules=MambaLayerSubmodules( - mixer=ModuleSpec( - module=MambaMixer, - submodules=MambaMixerSubmodules( - in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear - ), - ), - mamba_bda=get_bias_dropout_add, - ), - ), - # Started with spec from gpt_layer_specs.py (with MLP removed) - # Using the TE spec because we had problems getting the non-TE spec - # working - attention_layer=ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - ), - ), - # Started with spec from gpt_layer_specs.py - # Using the TE spec because we had problems getting the non-TE spec - # working - mlp_layer=ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ), - 
mlp_bda=get_bias_dropout_add, - ), - ), - ), -) +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +mamba_stack_spec = ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py (with MLP removed) + # Using the TE spec because we had problems getting the non-TE spec + # working + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py + # Using the TE spec because we had problems getting the non-TE spec + # working + mlp_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), +) diff --git a/megatron/core/models/multimodal/context_parallel.py b/megatron/core/models/multimodal/context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..8115fca64e4548e20d725b8ee211b314a072fba6 --- /dev/null +++ b/megatron/core/models/multimodal/context_parallel.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality.""" + +import torch + +from megatron.core.packed_seq_params import PackedSeqParams + + +def get_padding( + seq_len, cp_size, tp_size, has_sp, decoder_tp_comm_overlap=False, decoder_seq_len=None +): + """Calculate padding needed for SP and/or CP. + + Args: + seq_len (int): Model sequence length. + cp_size (int): Context parallel size. + tp_size (int): Tensor parallel size. + has_sp (bool): Model uses sequence parallelism. + decoder_tp_comm_overlap (bool): Decoder (LLM) uses tensor parallel communication overlap. + decoder_seq_len (int): Decoder (LLM) maximum sequence length. + + Returns: + padding (int): Padding needed given model configuration. + """ + + padding = 0 + # TP Comm overlap is performed with combined text+image embeddings. 
+ if has_sp and decoder_tp_comm_overlap: + # If TP Comm Overlap is enabled for combined text+image embedding in LM backbone, + # user needs to provide decoder_seq_len with any potential padding needed for SP+CP + assert ( + decoder_seq_len is not None + ), "Please provide decoder seq length when using TP comm overlap for LM backbone" + padding = decoder_seq_len - seq_len + elif has_sp or cp_size > 1: + padding_factor = 1 + if has_sp and cp_size > 1: + # Padding to multiple of tp_size * cp_size * 2 when using CP + SP. + padding_factor = tp_size * cp_size * 2 + elif cp_size > 1: + padding_factor = cp_size * 2 + elif has_sp: + padding_factor = tp_size + + padding = int((seq_len + padding_factor - 1) // padding_factor * padding_factor) - seq_len + + return padding + + +def get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, use_packed_sequence=False): + """Get PackedSeqParams for CP. + + Args: + tokens (torch.Tensor): [batch, seq_len] input tokens. + img_seq_len (int): Image sequence length. + padding_needed (int): Padding to add. + cp_size (int): Context parallel size. + use_packed_sequence (bool): Uses sequence packing. + + Returns: + packed_seq_params (PackedSeqParams): Parameters to be sent to Transformer Engine. + """ + batch_size = tokens.shape[0] + # Calculate the valid token seq len that LM backbone should compute on + combined_valid_seqlen = tokens.shape[1] + img_seq_len - padding_needed + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * (combined_valid_seqlen), + step=(combined_valid_seqlen), + dtype=torch.int32, + device=tokens.device, + ) + # Calculate the total padded token seq len + combined_padded_seqlen = tokens.shape[1] + img_seq_len + cu_seqlens_padded = None + qkv_format = 'sbhd' + if cp_size > 1 and (padding_needed > 0 or use_packed_sequence): + # Provide cu_seqlens__padded for CP support + cu_seqlens_padded = torch.arange( + 0, + (batch_size + 1) * (combined_padded_seqlen), + step=(combined_padded_seqlen), + dtype=torch.int32, + device=tokens.device, + ) + # CP with padding mask type requires THD format + qkv_format = 'thd' + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=combined_padded_seqlen, + max_seqlen_kv=combined_padded_seqlen, + qkv_format=qkv_format, + ) + + return packed_seq_params diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index 3de68b5091719cfd3638cf38ea37ab1963be2fda..09f83ac3343158a5e88159f8d5fd860c3689f1a2 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -1,924 +1,958 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
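The padding rule in `get_padding` above rounds the combined sequence length up to the multiple that SP/CP sharding requires. A short worked example follows; the sizes are arbitrary.

```python
# Worked example of the SP/CP padding rule; sizes are arbitrary.
from megatron.core.models.multimodal.context_parallel import get_padding

# SP + CP: pad to a multiple of tp_size * cp_size * 2 = 4 * 2 * 2 = 16.
print(get_padding(seq_len=1000, cp_size=2, tp_size=4, has_sp=True))   # 8  (1000 -> 1008)

# CP only: pad to a multiple of cp_size * 2 = 4.
print(get_padding(seq_len=1000, cp_size=2, tp_size=1, has_sp=False))  # 0  (already a multiple)

# With decoder TP comm overlap, the caller supplies the padded decoder_seq_len directly.
print(get_padding(
    seq_len=1000, cp_size=2, tp_size=4, has_sp=True,
    decoder_tp_comm_overlap=True, decoder_seq_len=1024,
))  # 24
```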
-import logging -from collections import namedtuple -from functools import partial -from typing import List, Optional - -import torch - -from megatron.core import InferenceParams, tensor_parallel -from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk -from megatron.core.models.gpt import GPTModel -from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings -from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.parallel_state import get_context_parallel_group, get_context_parallel_world_size -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import log_single_rank - -try: - import transformer_engine # pylint: disable=unused-import - from transformer_engine.pytorch.distributed import gather_along_first_dim - - from megatron.core.extensions.transformer_engine import TEDotProductAttention - from megatron.core.utils import is_te_min_version - - HAVE_TE = True -except: - HAVE_TE = False - if get_context_parallel_world_size() > 1: - raise RuntimeError("ContextParallelism requires TransformerEngine support, but not found.") - - -IGNORE_INDEX = -100 # ID for labels that should be ignored. -# Image token index can be tokenizer dependent so the default value does not work in all cases. -DEFAULT_IMAGE_TOKEN_INDEX = -200 -IMAGE_TOKEN = "" -VIDEO_TOKEN = "